Compare commits
60 Commits
main
...
541418fd60
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
541418fd60 | ||
|
|
dff5096164 | ||
|
|
8b580fc732 | ||
|
|
f9123889f2 | ||
|
|
9fd1c961e4 | ||
|
|
4ea15567c2 | ||
|
|
ef16c7e037 | ||
|
|
4b94e03413 | ||
|
|
0a1c7cff5c | ||
|
|
bd5ebdaa43 | ||
|
|
a090579db9 | ||
|
|
5e7db5e271 | ||
|
|
4a3b111ce4 | ||
|
|
fe3e26e21d | ||
|
|
8efe1ae5c5 | ||
|
|
12fac4ea1c | ||
|
|
1bace88f37 | ||
|
|
ba5ca7e0b3 | ||
|
|
e030f3a194 | ||
|
|
43e7e0ba17 | ||
|
|
0ee970d8bd | ||
|
|
ead2060ab3 | ||
|
|
bdc3254ed2 | ||
|
|
685dce2519 | ||
|
|
ec80107128 | ||
|
|
fffca13941 | ||
|
|
760b43cc68 | ||
|
|
3bc48598cd | ||
|
|
704b20cde1 | ||
|
|
d1ae364dbd | ||
|
|
75def0ff38 | ||
|
|
ff629e596d | ||
|
|
592dc6ceb1 | ||
|
|
573c0a3385 | ||
|
|
8aea76d99b | ||
|
|
61096ba190 | ||
|
|
f476545172 | ||
|
|
aae97f6ce9 | ||
|
|
8ce8831315 | ||
|
|
28fb62e5d6 | ||
|
|
42bfdd033c | ||
|
|
945077a453 | ||
|
|
0ce140a210 | ||
|
|
83fe0776eb | ||
|
|
a0019b8b0e | ||
|
|
2a0534ac62 | ||
|
|
3c45fba0f5 | ||
|
|
708925ab41 | ||
|
|
92ad32bb8e | ||
|
|
669d032f96 | ||
| b37c492930 | |||
|
|
46757e848f | ||
|
|
201de84ad0 | ||
|
|
85a5ff9ff0 | ||
|
|
e712e11ea0 | ||
| b552d20a46 | |||
| 21cec260b8 | |||
| 5a98bf2e85 | |||
|
|
f11b31c09c | ||
| 0b18ec353c |
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
/cpp_ext/build/
|
||||
/.cursor/
|
||||
/dist/
|
||||
403
4g_download_manager.py
Normal file
403
4g_download_manager.py
Normal file
@@ -0,0 +1,403 @@
|
||||
import re
|
||||
import hashlib
|
||||
import binascii
|
||||
from maix import time
|
||||
from power import get_bus_voltage, voltage_to_percent
|
||||
from urllib.parse import urlparse
|
||||
from hardware import hardware_manager
|
||||
|
||||
|
||||
class DownloadManager4G:
|
||||
"""4g下载管理器(单例)"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(DownloadManager4G, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 私有状态
|
||||
self.FRAG_SIZE = 1024
|
||||
self.FRAG_DELAY = 10
|
||||
self._initialized = True
|
||||
|
||||
def _log(self, *a):
|
||||
if debug:
|
||||
self.logger.debug(" ".join(str(x) for x in a))
|
||||
|
||||
def _pwr_log(self, prefix=""):
|
||||
"""debug 用:输出电压/电量"""
|
||||
if not debug:
|
||||
return
|
||||
try:
|
||||
v = get_bus_voltage()
|
||||
p = voltage_to_percent(v)
|
||||
self.logger.debug(f"[PWR]{prefix} v={v:.3f}V p={p}%")
|
||||
except Exception as e:
|
||||
try:
|
||||
self.logger.debug(f"[PWR]{prefix} read_failed: {e}")
|
||||
except:
|
||||
pass
|
||||
|
||||
def _clear_http_events(self):
|
||||
if hardware_manager.at_client:
|
||||
while hardware_manager.at_client.pop_http_event() is not None:
|
||||
pass
|
||||
|
||||
def _parse_httpid(self, resp: str):
|
||||
m = re.search(r"\+MHTTPCREATE:\s*(\d+)", resp)
|
||||
return int(m.group(1)) if m else None
|
||||
|
||||
def _get_ip(self, ):
|
||||
r = hardware_manager.at_client.send("AT+CGPADDR=1", "OK", 3000)
|
||||
m = re.search(r'\+CGPADDR:\s*1,"([^"]+)"', r)
|
||||
return m.group(1) if m else ""
|
||||
|
||||
def _ensure_pdp(self, ):
|
||||
ip = self._get_ip()
|
||||
if ip and ip != "0.0.0.0":
|
||||
return True, ip
|
||||
hardware_manager.at_client.send("AT+MIPCALL=1,1", "OK", 15000)
|
||||
for _ in range(10):
|
||||
ip = self._get_ip()
|
||||
if ip and ip != "0.0.0.0":
|
||||
return True, ip
|
||||
time.sleep(1)
|
||||
return False, ip
|
||||
|
||||
def _extract_hdr_fields(self, hdr_text: str):
|
||||
mlen = re.search(r"Content-Length:\s*(\d+)", hdr_text, re.IGNORECASE)
|
||||
clen = int(mlen.group(1)) if mlen else None
|
||||
mmd5 = re.search(r"Content-Md5:\s*([A-Za-z0-9+/=]+)", hdr_text, re.IGNORECASE)
|
||||
md5_b64 = mmd5.group(1).strip() if mmd5 else None
|
||||
return clen, md5_b64
|
||||
|
||||
def _extract_content_range(self, hdr_text: str):
|
||||
m = re.search(r"Content-Range:\s*bytes\s*(\d+)\s*-\s*(\d+)\s*/\s*(\d+)", hdr_text, re.IGNORECASE)
|
||||
if not m:
|
||||
return None, None, None
|
||||
try:
|
||||
return int(m.group(1)), int(m.group(2)), int(m.group(3))
|
||||
except:
|
||||
return None, None, None
|
||||
|
||||
def _hard_reset_http(self, ):
|
||||
"""模块进入"坏状态"时的保守清场"""
|
||||
self._clear_http_events()
|
||||
for i in range(0, 6):
|
||||
try:
|
||||
hardware_manager.at_client.send(f"AT+MHTTPDEL={i}", "OK", 1200)
|
||||
except:
|
||||
pass
|
||||
self._clear_http_events()
|
||||
|
||||
def _create_httpid(self, full_reset=False):
|
||||
self._clear_http_events()
|
||||
if hardware_manager.at_client:
|
||||
hardware_manager.at_client.flush()
|
||||
if full_reset:
|
||||
self._hard_reset_http()
|
||||
resp = hardware_manager.at_client.send(f'AT+MHTTPCREATE="{base_url}"', "OK", 8000)
|
||||
hid = self._parse_httpid(resp)
|
||||
if self._is_https:
|
||||
resp = hardware_manager.at_client.send(f'AT+MHTTPCFG="ssl",{hid},1,1', "OK", 2000)
|
||||
if "ERROR" in resp or "CME ERROR" in resp:
|
||||
self.logger.error(f"MHTTPCFG SSL failed: {resp}")
|
||||
# 尝试https 降级到http
|
||||
downgraded_base_url = base_url.replace("https://", "http://")
|
||||
resp = hardware_manager.at_client.send(f'AT+MHTTPCREATE="{downgraded_base_url}"', "OK", 8000)
|
||||
hid = self._parse_httpid(resp)
|
||||
|
||||
return hid, resp
|
||||
|
||||
def _fetch_range_into_buf(self, start, want_len, out_buf, path, full_reset=False):
|
||||
"""
|
||||
请求 Range [start, start+want_len),写入 out_buf(bytearray,长度=want_len)
|
||||
返回 (ok, msg, total_len, md5_b64, got_len)
|
||||
"""
|
||||
end_incl = start + want_len - 1
|
||||
hid, cresp = self._create_httpid(full_reset=full_reset)
|
||||
if hid is None:
|
||||
return False, f"MHTTPCREATE failed: {cresp}", None, None, 0
|
||||
|
||||
# 降低 URC 压力(分片/延迟)
|
||||
hardware_manager.at_client.send(f'AT+MHTTPCFG="fragment",{hid},{self.FRAG_SIZE},{self.FRAG_DELAY}', "OK", 1500)
|
||||
# 设置 Range header(inclusive)
|
||||
hardware_manager.at_client.send(f'AT+MHTTPCFG="header",{hid},"Range: bytes={start}-{end_incl}"', "OK", 3000)
|
||||
|
||||
req = hardware_manager.at_client.send(f'AT+MHTTPREQUEST={hid},1,0,"{path}"', "OK", 15000)
|
||||
if "ERROR" in req or "CME ERROR" in req:
|
||||
hardware_manager.at_client.send(f"AT+MHTTPDEL={hid}", "OK", 2000)
|
||||
return False, f"MHTTPREQUEST failed: {req}", None, None, 0
|
||||
|
||||
# 等 header + content
|
||||
hdr_text = None
|
||||
hdr_accum = ""
|
||||
code = None
|
||||
resp_total = None
|
||||
total_len = None
|
||||
md5_b64 = None
|
||||
|
||||
got_ranges = set()
|
||||
last_sum = 0
|
||||
t0 = time.ticks_ms()
|
||||
timeout_ms = 9000
|
||||
logged_hdr = False
|
||||
|
||||
while time.ticks_ms() - t0 < timeout_ms:
|
||||
ev = hardware_manager.at_client.pop_http_event() if hardware_manager.at_client else None
|
||||
if not ev:
|
||||
time.sleep_ms(5)
|
||||
continue
|
||||
|
||||
if ev[0] == "header":
|
||||
_, ehid, ecode, ehdr = ev
|
||||
if ehid != hid:
|
||||
continue
|
||||
code = ecode
|
||||
hdr_text = ehdr
|
||||
if ehdr:
|
||||
hdr_accum = (hdr_accum + "\n" + ehdr) if hdr_accum else ehdr
|
||||
|
||||
resp_total_tmp, md5_tmp = self._extract_hdr_fields(hdr_accum)
|
||||
if md5_tmp:
|
||||
md5_b64 = md5_tmp
|
||||
cr_s, cr_e, cr_total = self._extract_content_range(hdr_accum)
|
||||
if cr_total is not None:
|
||||
total_len = cr_total
|
||||
if resp_total_tmp is not None:
|
||||
resp_total = resp_total_tmp
|
||||
elif resp_total is None and (cr_s is not None) and (cr_e is not None) and (cr_e >= cr_s):
|
||||
resp_total = (cr_e - cr_s + 1)
|
||||
if (not logged_hdr) and (resp_total is not None or total_len is not None):
|
||||
self._log(f"[HDR] id={hid} code={code} clen={resp_total} cr={cr_s}-{cr_e}/{cr_total}")
|
||||
logged_hdr = True
|
||||
continue
|
||||
|
||||
if ev[0] == "content":
|
||||
_, ehid, _total, _sum, _cur, payload = ev
|
||||
if ehid != hid:
|
||||
continue
|
||||
if resp_total is None:
|
||||
resp_total = _total
|
||||
if resp_total is None or resp_total <= 0:
|
||||
continue
|
||||
start_rel = _sum - _cur
|
||||
end_rel = _sum
|
||||
if start_rel < 0 or start_rel >= resp_total:
|
||||
continue
|
||||
if end_rel > resp_total:
|
||||
end_rel = resp_total
|
||||
actual_len = min(len(payload), end_rel - start_rel)
|
||||
if actual_len <= 0:
|
||||
continue
|
||||
out_buf[start_rel:start_rel + actual_len] = payload[:actual_len]
|
||||
got_ranges.add((start_rel, start_rel + actual_len))
|
||||
if _sum > last_sum:
|
||||
last_sum = _sum
|
||||
if debug and (last_sum >= resp_total or (last_sum % 512 == 0)):
|
||||
self._log(f"[CHUNK] {start}+{last_sum}/{resp_total}")
|
||||
|
||||
if last_sum >= resp_total:
|
||||
break
|
||||
|
||||
# 清理实例(快路径:只删当前 hid)
|
||||
try:
|
||||
hardware_manager.at_client.send(f"AT+MHTTPDEL={hid}", "OK", 2000)
|
||||
except:
|
||||
pass
|
||||
|
||||
if resp_total is None:
|
||||
return False, "no_header_or_total", total_len, md5_b64, 0
|
||||
|
||||
# 计算实际填充长度
|
||||
merged = sorted(got_ranges)
|
||||
merged2 = []
|
||||
for s, e in merged:
|
||||
if not merged2 or s > merged2[-1][1]:
|
||||
merged2.append((s, e))
|
||||
else:
|
||||
merged2[-1] = (merged2[-1][0], max(merged2[-1][1], e))
|
||||
filled = sum(e - s for s, e in merged2)
|
||||
|
||||
if filled < resp_total:
|
||||
return False, f"incomplete_chunk got={filled} expected={resp_total} code={code}", total_len, md5_b64, filled
|
||||
|
||||
got_len = resp_total
|
||||
return True, "OK", total_len, md5_b64, got_len
|
||||
|
||||
def download_file_via_4g(self, url, filename,
|
||||
total_timeout_ms=600000,
|
||||
retries=3,
|
||||
debug=False):
|
||||
"""
|
||||
ML307R HTTP 下载(更稳的"固定小块 Range 顺序下载",基于main109.py):
|
||||
- 只依赖 +MHTTPURC:"header"/"content"(不依赖 MHTTPREAD/cached)
|
||||
- 每次只请求一个小块 Range(默认 10240B),失败就重试同一块,必要时缩小块大小
|
||||
- 每个 chunk 都重新 MHTTPCREATE/MHTTPREQUEST,避免卡在"206 header 但不吐 content"的坏状态
|
||||
- 使用二进制模式下载,确保文件完整性
|
||||
"""
|
||||
|
||||
|
||||
# 小块策略(与main109.py保持一致)
|
||||
CHUNK_MAX = 10240
|
||||
CHUNK_MIN = 128
|
||||
CHUNK_RETRIES = 12
|
||||
|
||||
|
||||
t_func0 = time.ticks_ms()
|
||||
|
||||
parsed = urlparse(url)
|
||||
host = parsed.hostname
|
||||
path = parsed.path or "/"
|
||||
if parsed.query:
|
||||
path = f"{path}?{parsed.query}"
|
||||
if parsed.fragment:
|
||||
path = f"{path}#{parsed.fragment}"
|
||||
if not host:
|
||||
return False, "bad_url (no host)"
|
||||
|
||||
if isinstance(url, str) and url.startswith("https://static.shelingxingqiu.com/"):
|
||||
base_url = "https://static.shelingxingqiu.com"
|
||||
# TODO:使用https,看看是否能成功
|
||||
self._is_https = True
|
||||
else:
|
||||
base_url = f"http://{host}"
|
||||
self._is_https = False
|
||||
|
||||
|
||||
try:
|
||||
self._begin_ota()
|
||||
except:
|
||||
pass
|
||||
|
||||
from network import network_manager
|
||||
with network_manager.get_uart_lock():
|
||||
try:
|
||||
ok_pdp, ip = self._ensure_pdp()
|
||||
if not ok_pdp:
|
||||
return False, f"PDP not ready (ip={ip})"
|
||||
|
||||
# 先清空旧事件,避免串台
|
||||
self._clear_http_events()
|
||||
|
||||
# 为了支持随机写入,先创建空文件
|
||||
try:
|
||||
with open(filename, "wb") as f:
|
||||
f.write(b"")
|
||||
except Exception as e:
|
||||
return False, f"open_file_failed: {e}"
|
||||
|
||||
total_len = None
|
||||
expect_md5_b64 = None
|
||||
|
||||
offset = 0
|
||||
chunk = CHUNK_MAX
|
||||
t_start = time.ticks_ms()
|
||||
last_progress_ms = t_start
|
||||
STALL_TIMEOUT_MS = 60000
|
||||
last_pwr_ms = t_start
|
||||
self._pwr_log(prefix=" ota_start")
|
||||
bad_http_state = 0
|
||||
|
||||
while True:
|
||||
now = time.ticks_ms()
|
||||
if debug and time.ticks_diff(now, last_pwr_ms) >= 5000:
|
||||
last_pwr_ms = now
|
||||
self._pwr_log(prefix=f" off={offset}/{total_len or '?'}")
|
||||
if time.ticks_diff(now, t_start) > total_timeout_ms:
|
||||
return False, f"timeout overall after {total_timeout_ms}ms offset={offset} total={total_len}"
|
||||
|
||||
if time.ticks_diff(now, last_progress_ms) > STALL_TIMEOUT_MS:
|
||||
return False, f"timeout stalled {STALL_TIMEOUT_MS}ms offset={offset} total={total_len}"
|
||||
|
||||
if total_len is not None and offset >= total_len:
|
||||
break
|
||||
|
||||
want = chunk
|
||||
if total_len is not None:
|
||||
remain = total_len - offset
|
||||
if remain <= 0:
|
||||
break
|
||||
if want > remain:
|
||||
want = remain
|
||||
|
||||
# 本 chunk 的 buffer(长度=want)
|
||||
buf = bytearray(want)
|
||||
|
||||
success = False
|
||||
last_err = "unknown"
|
||||
md5_seen = None
|
||||
got_len = 0
|
||||
for k in range(1, CHUNK_RETRIES + 1):
|
||||
do_full_reset = (bad_http_state >= 2)
|
||||
ok, msg, tlen, md5_b64, got = self._fetch_range_into_buf(offset, want, buf, base_url, path, full_reset=do_full_reset)
|
||||
last_err = msg
|
||||
if tlen is not None and total_len is None:
|
||||
total_len = tlen
|
||||
if md5_b64 and not expect_md5_b64:
|
||||
expect_md5_b64 = md5_b64
|
||||
if ok:
|
||||
success = True
|
||||
got_len = got
|
||||
bad_http_state = 0
|
||||
break
|
||||
|
||||
try:
|
||||
if ("no_header_or_total" in msg) or ("MHTTPREQUEST failed" in msg) or (
|
||||
"MHTTPCREATE failed" in msg):
|
||||
bad_http_state += 1
|
||||
else:
|
||||
bad_http_state = max(0, bad_http_state - 1)
|
||||
except:
|
||||
pass
|
||||
|
||||
if chunk > CHUNK_MIN:
|
||||
chunk = max(CHUNK_MIN, chunk // 2)
|
||||
want = min(chunk, want)
|
||||
buf = bytearray(want)
|
||||
self._log(f"[RETRY] off={offset} want={want} try={k} err={msg}")
|
||||
self._pwr_log(prefix=f" retry{k} off={offset}")
|
||||
time.sleep_ms(120)
|
||||
|
||||
if not success:
|
||||
return False, f"chunk_failed off={offset} want={want} err={last_err} total={total_len}"
|
||||
|
||||
# 写入文件(二进制模式)
|
||||
try:
|
||||
with open(filename, "r+b") as f:
|
||||
f.seek(offset)
|
||||
f.write(bytes(buf))
|
||||
except Exception as e:
|
||||
return False, f"write_failed off={offset}: {e}"
|
||||
|
||||
offset += len(buf)
|
||||
last_progress_ms = time.ticks_ms()
|
||||
chunk = CHUNK_MAX
|
||||
if debug:
|
||||
self._log(f"[OK] offset={offset}/{total_len or '?'}")
|
||||
|
||||
# MD5 校验
|
||||
if expect_md5_b64 and hashlib is not None:
|
||||
try:
|
||||
with open(filename, "rb") as f:
|
||||
data = f.read()
|
||||
digest = hashlib.md5(data).digest()
|
||||
got_b64 = binascii.b2a_base64(digest).decode().strip()
|
||||
if got_b64 != expect_md5_b64:
|
||||
return False, f"md5_mismatch got={got_b64} expected={expect_md5_b64}"
|
||||
self.logger.debug(f"[4G-DL] MD5 verified: {got_b64}")
|
||||
except Exception as e:
|
||||
return False, f"md5_check_failed: {e}"
|
||||
|
||||
t_cost = time.ticks_diff(time.ticks_ms(), t_func0)
|
||||
self.logger.info(f"[4G-DL] download complete: size={offset} ip={ip} cost_ms={t_cost}")
|
||||
return True, f"OK size={offset} ip={ip} cost_ms={t_cost}"
|
||||
|
||||
finally:
|
||||
self._end_ota()
|
||||
450
4g_upload_manager.py
Normal file
450
4g_upload_manager.py
Normal file
@@ -0,0 +1,450 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
4G Image Upload Manager
|
||||
Uploads images to Qiniu cloud via ML307R 4G module TCP socket (MIPOPEN + MIPSEND).
|
||||
|
||||
AT Command Sequence (ML307R TCP socket POST):
|
||||
AT+MIPCALL=1,1 // Ensure PDP context active
|
||||
AT+MIPCLOSE=<id> // Close old socket (ignore error)
|
||||
AT+MIPOPEN=<id>,"TCP","<host>",80 // Open TCP socket
|
||||
// Wait for +MIPOPEN: <id>,0 (success)
|
||||
AT+MIPSEND=<id>,<len> // Send data
|
||||
// Wait for ">" prompt, then write raw bytes
|
||||
// Repeat MIPSEND for all chunks
|
||||
// Wait for +MIPURC: "rtcp" response
|
||||
AT+MIPCLOSE=<id> // Close socket
|
||||
"""
|
||||
|
||||
import re
|
||||
import os
|
||||
import json
|
||||
from maix import time
|
||||
from urllib.parse import urlparse
|
||||
from logger_manager import logger_manager
|
||||
from hardware import hardware_manager
|
||||
|
||||
# Multipart form boundary (simple alphanumeric to avoid AT command parser issues)
|
||||
BOUNDARY = "QiniuFormBoundary" + hex(int(time.time()))[2:]
|
||||
# Chunk size for MIPSEND (max 1024 to avoid AT line buffer limits)
|
||||
SEND_CHUNK = 1024
|
||||
# Socket ID for upload (dedicated to avoid conflict with main app TCP)
|
||||
UPLOAD_SOCK_ID = 3
|
||||
|
||||
|
||||
class FourGUploadManager:
|
||||
"""4G image upload manager using ML307R TCP socket (MIPOPEN + MIPSEND)"""
|
||||
|
||||
def __init__(self, at_client):
|
||||
"""Initialize with AT client instance"""
|
||||
self.at = at_client
|
||||
self.logger = logger_manager.logger
|
||||
|
||||
# ------------------------------------------------------------------ logging
|
||||
def _log(self, msg):
|
||||
try:
|
||||
self.logger.debug("[4G-UL] " + msg)
|
||||
except Exception:
|
||||
print("[4G-UL] " + msg)
|
||||
|
||||
def _log_info(self, msg):
|
||||
try:
|
||||
self.logger.info("[4G-UL] " + msg)
|
||||
except Exception:
|
||||
print("[4G-UL] " + msg)
|
||||
|
||||
def _log_error(self, msg):
|
||||
try:
|
||||
self.logger.error("[4G-UL] " + msg)
|
||||
except Exception:
|
||||
print("[4G-UL] " + msg)
|
||||
|
||||
# --------------------------------------------------------------- helpers
|
||||
def _ensure_pdp(self):
|
||||
"""Ensure PDP context is active; returns (ok, ip)"""
|
||||
r = self.at.send("AT+CGPADDR=1", "OK", 3000)
|
||||
m = re.search(r'\+CGPADDR:\s*1,"([^"]+)"', r)
|
||||
ip = m.group(1) if m else ""
|
||||
if ip and ip != "0.0.0.0":
|
||||
return True, ip
|
||||
self.at.send("AT+MIPCALL=1,1", "OK", 15000)
|
||||
for _ in range(10):
|
||||
r = self.at.send("AT+CGPADDR=1", "OK", 3000)
|
||||
m = re.search(r'\+CGPADDR:\s*1,"([^"]+)"', r)
|
||||
ip = m.group(1) if m else ""
|
||||
if ip and ip != "0.0.0.0":
|
||||
return True, ip
|
||||
time.sleep(1)
|
||||
return False, ip
|
||||
|
||||
def _is_error(self, resp):
|
||||
"""Check AT response for any error indicators"""
|
||||
return "ERROR" in resp or "CME ERROR" in resp
|
||||
|
||||
# --------------------------------------------------------- multipart body
|
||||
def _build_multipart_body(self, image_path, upload_token, key):
|
||||
"""
|
||||
Build multipart/form-data body as bytes for Qiniu upload.
|
||||
|
||||
Fields:
|
||||
- token : Qiniu upload token
|
||||
- key : object key in bucket
|
||||
- file : binary image data
|
||||
"""
|
||||
boundary = BOUNDARY.encode()
|
||||
|
||||
with open(image_path, "rb") as f:
|
||||
file_data = f.read()
|
||||
|
||||
filename = os.path.basename(image_path)
|
||||
ext = os.path.splitext(image_path)[1].lower()
|
||||
ct_map = {
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".bmp": "image/bmp",
|
||||
".webp": "image/webp",
|
||||
}
|
||||
content_type = ct_map.get(ext, "application/octet-stream")
|
||||
|
||||
body = bytearray()
|
||||
|
||||
# -- token field --
|
||||
body += b"--" + boundary + b"\r\n"
|
||||
body += b'Content-Disposition: form-data; name="token"\r\n'
|
||||
body += b"\r\n"
|
||||
body += upload_token.encode("utf-8") + b"\r\n"
|
||||
|
||||
# -- key field --
|
||||
body += b"--" + boundary + b"\r\n"
|
||||
body += b'Content-Disposition: form-data; name="key"\r\n'
|
||||
body += b"\r\n"
|
||||
body += key.encode("utf-8") + b"\r\n"
|
||||
|
||||
# -- file field --
|
||||
body += b"--" + boundary + b"\r\n"
|
||||
body += (
|
||||
b'Content-Disposition: form-data; name="file"; filename="'
|
||||
+ filename.encode("utf-8")
|
||||
+ b'"\r\n'
|
||||
)
|
||||
body += b"Content-Type: " + content_type.encode("utf-8") + b"\r\n"
|
||||
body += b"\r\n"
|
||||
body += file_data + b"\r\n"
|
||||
|
||||
# -- closing boundary --
|
||||
body += b"--" + boundary + b"--\r\n"
|
||||
|
||||
return bytes(body)
|
||||
|
||||
# --------------------------------------------------- TCP socket helpers
|
||||
def _close_socket(self, sock_id):
|
||||
"""Close socket, ignore CME ERROR 55 (already closed)"""
|
||||
try:
|
||||
resp = self.at.send("AT+MIPCLOSE=" + str(sock_id), "OK", 5000)
|
||||
self._log("socket " + str(sock_id) + " closed: " + resp)
|
||||
except Exception as e:
|
||||
# Ignore CME ERROR 55 (socket not open)
|
||||
self._log("socket close (may already be closed): " + str(e))
|
||||
|
||||
def _open_socket(self, sock_id, host, port):
|
||||
"""
|
||||
Open TCP socket to host:port.
|
||||
Returns (success, error_msg)
|
||||
"""
|
||||
cmd = 'AT+MIPOPEN=' + str(sock_id) + ',"TCP","' + host + '",' + str(port)
|
||||
resp = self.at.send(cmd, "OK", 15000)
|
||||
|
||||
if self._is_error(resp):
|
||||
return False, "MIPOPEN failed: " + resp
|
||||
|
||||
# Wait for +MIPOPEN: <id>,0 (success) or +MIPOPEN: <id>,<error_code>
|
||||
# The URC may come in the same response or separately
|
||||
mipopen_pattern = r"\+MIPOPEN:\s*" + str(sock_id) + r",(\d+)"
|
||||
m = re.search(mipopen_pattern, resp)
|
||||
|
||||
if m:
|
||||
result_code = int(m.group(1))
|
||||
if result_code == 0:
|
||||
return True, ""
|
||||
else:
|
||||
return False, "MIPOPEN error code: " + str(result_code)
|
||||
|
||||
# If not in initial response, wait for URC
|
||||
try:
|
||||
urc_resp = self.at.send("", "+MIPOPEN:", 15000)
|
||||
m = re.search(mipopen_pattern, urc_resp)
|
||||
if m:
|
||||
result_code = int(m.group(1))
|
||||
if result_code == 0:
|
||||
return True, ""
|
||||
else:
|
||||
return False, "MIPOPEN error code: " + str(result_code)
|
||||
except Exception as e:
|
||||
return False, "MIPOPEN URC timeout: " + str(e)
|
||||
|
||||
return False, "MIPOPEN no response"
|
||||
|
||||
def _send_chunk(self, sock_id, chunk):
|
||||
"""
|
||||
Send a single chunk via MIPSEND.
|
||||
Thread safety is provided by the outer network_manager.get_uart_lock().
|
||||
NOTE: Do NOT add self.at._cmd_lock here — self.at.send() already
|
||||
acquires it internally and threading.Lock is not reentrant.
|
||||
Returns (success, error_msg)
|
||||
"""
|
||||
chunk_len = len(chunk)
|
||||
|
||||
# Step 1: Send AT+MIPSEND command and wait for ">" prompt
|
||||
cmd = "AT+MIPSEND=" + str(sock_id) + "," + str(chunk_len)
|
||||
try:
|
||||
resp = self.at.send(cmd, ">", 3000)
|
||||
if ">" not in resp:
|
||||
return False, "MIPSEND no > prompt: " + resp
|
||||
except Exception as e:
|
||||
return False, "MIPSEND > prompt error: " + str(e)
|
||||
|
||||
# Step 2: Write raw binary bytes directly to UART
|
||||
# Must be done immediately after ">" prompt, no lock re-acquisition
|
||||
try:
|
||||
self.at.uart.write(chunk)
|
||||
except Exception as e:
|
||||
return False, "MIPSEND write error: " + str(e)
|
||||
|
||||
# Step 3: Wait for OK or SEND OK confirmation
|
||||
try:
|
||||
confirm_resp = self.at.send("", "OK", 8000)
|
||||
if self._is_error(confirm_resp):
|
||||
return False, "MIPSEND confirmation error: " + confirm_resp
|
||||
except Exception as e:
|
||||
return False, "MIPSEND confirmation timeout: " + str(e)
|
||||
|
||||
return True, ""
|
||||
|
||||
def _send_data(self, sock_id, data):
|
||||
"""
|
||||
Send data in chunks via MIPSEND.
|
||||
Returns (success, error_msg)
|
||||
"""
|
||||
total_len = len(data)
|
||||
offset = 0
|
||||
chunk_num = 0
|
||||
|
||||
while offset < total_len:
|
||||
end = min(offset + SEND_CHUNK, total_len)
|
||||
chunk = data[offset:end]
|
||||
|
||||
ok, err = self._send_chunk(sock_id, chunk)
|
||||
if not ok:
|
||||
return False, "Chunk " + str(chunk_num) + " failed: " + err
|
||||
|
||||
chunk_num += 1
|
||||
offset = end
|
||||
|
||||
if chunk_num % 10 == 0 or offset >= total_len:
|
||||
self._log(
|
||||
"send progress: "
|
||||
+ str(offset) + "/" + str(total_len)
|
||||
+ " bytes (" + str(chunk_num) + " chunks)"
|
||||
)
|
||||
|
||||
self._log("all data sent: " + str(chunk_num) + " chunks, " + str(total_len) + " bytes")
|
||||
return True, ""
|
||||
|
||||
def _wait_for_response(self, sock_id, timeout_ms=30000):
|
||||
"""
|
||||
Wait for +MIPURC: "rtcp" response.
|
||||
Returns (success, status_code, body, error_msg)
|
||||
"""
|
||||
pattern = r'\+MIPURC:\s*"rtcp",\s*' + str(sock_id) + r',\s*(\d+),'
|
||||
t0 = time.ticks_ms()
|
||||
|
||||
while time.ticks_diff(time.ticks_ms(), t0) < timeout_ms:
|
||||
try:
|
||||
# Try to get response with short timeout
|
||||
resp = self.at.send("", "+MIPURC:", 1000)
|
||||
m = re.search(pattern, resp)
|
||||
if m:
|
||||
data_len = int(m.group(1))
|
||||
# Extract HTTP response data after the URC header
|
||||
# Format: +MIPURC: "rtcp",<sock_id>,<len>,<data>
|
||||
urc_end = resp.find("+MIPURC:")
|
||||
if urc_end >= 0:
|
||||
# Find the data after the length field
|
||||
match_end = m.end()
|
||||
http_data = resp[match_end:match_end + data_len]
|
||||
|
||||
# Parse HTTP status line
|
||||
status_match = re.search(r"HTTP/\d\.\d\s+(\d+)", http_data)
|
||||
status_code = int(status_match.group(1)) if status_match else None
|
||||
|
||||
# Extract body (after headers)
|
||||
header_end = http_data.find("\r\n\r\n")
|
||||
if header_end >= 0:
|
||||
body = http_data[header_end + 4:]
|
||||
else:
|
||||
body = http_data
|
||||
|
||||
return True, status_code, body, ""
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time.sleep_ms(100)
|
||||
|
||||
return False, None, "", "Response timeout"
|
||||
|
||||
def _build_http_request(self, host, body_bytes):
|
||||
"""
|
||||
Build full HTTP POST request as bytes.
|
||||
"""
|
||||
headers = (
|
||||
"POST / HTTP/1.1\r\n"
|
||||
"Host: " + host + "\r\n"
|
||||
"Content-Type: multipart/form-data; boundary=" + BOUNDARY + "\r\n"
|
||||
"Content-Length: " + str(len(body_bytes)) + "\r\n"
|
||||
"Connection: close\r\n"
|
||||
"\r\n"
|
||||
)
|
||||
return headers.encode("utf-8") + body_bytes
|
||||
|
||||
# ============================================================ public API
|
||||
def upload_file(self, file_path, upload_url, upload_token, key):
|
||||
"""Generic file upload to Qiniu cloud via 4G TCP socket POST.
|
||||
|
||||
Args:
|
||||
file_path: Local path to any file
|
||||
upload_url: Qiniu upload URL
|
||||
upload_token: Qiniu upload token
|
||||
key: File key in Qiniu bucket
|
||||
|
||||
Returns:
|
||||
dict with 'success' bool and 'key'/'error' fields
|
||||
"""
|
||||
return self.upload_image(file_path, upload_url, upload_token, key)
|
||||
|
||||
def upload_image(self, image_path, upload_url, upload_token, key):
|
||||
"""
|
||||
Upload image to Qiniu cloud via 4G TCP socket POST.
|
||||
|
||||
Args:
|
||||
image_path: Local path to image file
|
||||
upload_url: Qiniu upload URL (e.g., "https://upload.qiniup.com")
|
||||
upload_token: Qiniu upload token
|
||||
key: File key in Qiniu (e.g., "shootPic/device01/shoot01.png")
|
||||
|
||||
Returns:
|
||||
dict with 'success' bool and 'key'/'error' fields
|
||||
"""
|
||||
if not self.at:
|
||||
return {"success": False, "error": "AT client not available"}
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
return {"success": False, "error": "Image file not found: " + image_path}
|
||||
|
||||
# Force HTTP for 4G module (extract hostname, use port 80)
|
||||
parsed = urlparse(upload_url)
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
return {"success": False, "error": "Invalid upload URL: " + upload_url}
|
||||
|
||||
if upload_url.lower().startswith("https://"):
|
||||
self._log_info("Converted HTTPS->HTTP for 4G module")
|
||||
|
||||
file_size = os.path.getsize(image_path)
|
||||
self._log_info(
|
||||
"upload: " + image_path + " (" + str(file_size) + "B) -> "
|
||||
+ host + " key=" + key
|
||||
)
|
||||
|
||||
from network import network_manager
|
||||
with network_manager.get_uart_lock():
|
||||
try:
|
||||
# ---- Step 1: Ensure PDP context ----
|
||||
ok_pdp, ip = self._ensure_pdp()
|
||||
if not ok_pdp:
|
||||
return {"success": False, "error": "PDP not ready (ip=" + str(ip) + ")"}
|
||||
|
||||
# ---- Step 2: Close old socket ----
|
||||
self._close_socket(UPLOAD_SOCK_ID)
|
||||
|
||||
# ---- Step 3: Open TCP socket ----
|
||||
ok, err = self._open_socket(UPLOAD_SOCK_ID, host, 80)
|
||||
if not ok:
|
||||
return {"success": False, "error": "Socket open failed: " + err}
|
||||
|
||||
try:
|
||||
# ---- Step 4: Build multipart body and HTTP request ----
|
||||
body = self._build_multipart_body(image_path, upload_token, key)
|
||||
http_request = self._build_http_request(host, body)
|
||||
self._log("HTTP request size: " + str(len(http_request)) + " bytes")
|
||||
|
||||
# ---- Step 5: Send data via MIPSEND ----
|
||||
ok, err = self._send_data(UPLOAD_SOCK_ID, http_request)
|
||||
if not ok:
|
||||
return {"success": False, "error": "Send failed: " + err}
|
||||
|
||||
# ---- Step 6: Wait for response ----
|
||||
ok, status_code, resp_body, err = self._wait_for_response(UPLOAD_SOCK_ID)
|
||||
if not ok:
|
||||
return {"success": False, "error": "Response error: " + err}
|
||||
|
||||
# ---- Step 7: Parse response ----
|
||||
if status_code is None:
|
||||
return {"success": False, "error": "No HTTP status in response"}
|
||||
|
||||
if 200 <= status_code < 300:
|
||||
try:
|
||||
resp_json = json.loads(resp_body)
|
||||
resp_key = resp_json.get("key", key)
|
||||
self._log_info("upload success: key=" + resp_key + " code=" + str(status_code))
|
||||
return {"success": True, "key": resp_key}
|
||||
except Exception as e:
|
||||
self._log_error("response parse error: " + str(e))
|
||||
return {
|
||||
"success": True,
|
||||
"key": key,
|
||||
"raw": resp_body,
|
||||
}
|
||||
else:
|
||||
self._log_error(
|
||||
"HTTP error: code=" + str(status_code) + " body=" + resp_body[:200]
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"error": "HTTP " + str(status_code),
|
||||
"response": resp_body,
|
||||
}
|
||||
|
||||
finally:
|
||||
# ---- Step 8: Always close socket ----
|
||||
self._close_socket(UPLOAD_SOCK_ID)
|
||||
|
||||
except Exception as e:
|
||||
self._log_error("upload exception: " + str(e))
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
# ====================================================================== demo
|
||||
if __name__ == "__main__":
|
||||
# Demo usage — requires actual ML307R 4G module hardware to run.
|
||||
print("FourGUploadManager - requires ML307R 4G module hardware")
|
||||
print()
|
||||
print("Usage:")
|
||||
print(" from hardware import hardware_manager")
|
||||
print(" from at_client import ATClient")
|
||||
print(" from maix import uart")
|
||||
print()
|
||||
print(" # Initialize UART and AT client (normally done in hardware init)")
|
||||
print(" uart4g = uart.UART('/dev/ttyS1', 115200, ...)")
|
||||
print(" at_client = ATClient(uart4g)")
|
||||
print(" at_client.start()")
|
||||
print()
|
||||
print(" # Upload image to Qiniu")
|
||||
print(" uploader = FourGUploadManager(at_client)")
|
||||
print(" result = uploader.upload_image(")
|
||||
print(" image_path='/maixapp/apps/t11/shoot.png',")
|
||||
print(" upload_url='https://upload.qiniup.com',")
|
||||
print(" upload_token='<qiniu_upload_token>',")
|
||||
print(" key='shootPic/device01/shoot01.png'")
|
||||
print(" )")
|
||||
print(" print('Upload result:', result)")
|
||||
79
S99archery
Normal file
79
S99archery
Normal file
@@ -0,0 +1,79 @@
|
||||
#!/bin/sh
|
||||
# /etc/init.d/S99archery
|
||||
# 系统启动时处理致命错误恢复(仅处理无法启动的情况)
|
||||
# 注意:应用的启动由系统自动启动机制处理(通过 auto_start.txt)
|
||||
# 功能:
|
||||
# 1. 处理致命错误(无法启动)- 恢复 main.py
|
||||
# 2. 如果重启次数超过阈值,恢复 main.py 并重启系统
|
||||
|
||||
APP_DIR="/maixapp/apps/t11"
|
||||
MAIN_PY="$APP_DIR/main.py"
|
||||
PENDING_FILE="$APP_DIR/ota_pending.json"
|
||||
BACKUP_BASE="$APP_DIR/backups"
|
||||
|
||||
# 进入应用目录
|
||||
cd "$APP_DIR" || exit 0
|
||||
|
||||
# 检查 pending 文件,如果存在且超过重启次数,恢复 main.py(处理致命错误)
|
||||
if [ -f "$PENDING_FILE" ]; then
|
||||
echo "[S99] 检测到 ota_pending.json,检查重启计数..."
|
||||
|
||||
# 尝试从JSON中提取重启计数(使用grep简单提取)
|
||||
RESTART_COUNT=$(cat "$PENDING_FILE" 2>/dev/null | grep -o '"restart_count":[0-9]*' | grep -o '[0-9]*' || echo "0")
|
||||
MAX_RESTARTS=$(cat "$PENDING_FILE" 2>/dev/null | grep -o '"max_restarts":[0-9]*' | grep -o '[0-9]*' || echo "3")
|
||||
|
||||
if [ -n "$RESTART_COUNT" ] && [ "$RESTART_COUNT" -ge "$MAX_RESTARTS" ]; then
|
||||
echo "[S99] 检测到重启次数 ($RESTART_COUNT) 超过阈值 ($MAX_RESTARTS),恢复 main.py..."
|
||||
|
||||
# 尝试从JSON中提取备份目录
|
||||
BACKUP_DIR=$(cat "$PENDING_FILE" 2>/dev/null | grep -o '"backup_dir":"[^"]*"' | grep -o '/[^"]*' || echo "")
|
||||
|
||||
if [ -n "$BACKUP_DIR" ] && [ -f "$BACKUP_DIR/main.py" ]; then
|
||||
# 使用指定的备份目录
|
||||
echo "[S99] 从备份目录恢复: $BACKUP_DIR/main.py"
|
||||
cp "$BACKUP_DIR/main.py" "$MAIN_PY" 2>/dev/null && echo "[S99] 已恢复 main.py"
|
||||
else
|
||||
# 查找最新的备份目录
|
||||
LATEST_BACKUP=$(ls -dt "$BACKUP_BASE"/backup_* 2>/dev/null | head -1)
|
||||
if [ -n "$LATEST_BACKUP" ] && [ -f "$LATEST_BACKUP/main.py" ]; then
|
||||
echo "[S99] 从最新备份恢复: $LATEST_BACKUP/main.py"
|
||||
cp "$LATEST_BACKUP/main.py" "$MAIN_PY" 2>/dev/null && echo "[S99] 已恢复 main.py"
|
||||
else
|
||||
# 如果没有备份目录,尝试使用 main.py.bak
|
||||
if [ -f "$APP_DIR/main.py.bak" ]; then
|
||||
echo "[S99] 从 main.py.bak 恢复"
|
||||
cp "$APP_DIR/main.py.bak" "$MAIN_PY" 2>/dev/null && echo "[S99] 已恢复 main.py"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
# 恢复后重置重启计数,避免循环恢复
|
||||
# 注意:不在这里删除 pending 文件,让 main.py 在心跳成功后删除
|
||||
# 但是重置重启计数,以便恢复后的版本可以重新开始计数
|
||||
python3 -c "
|
||||
import json, os
|
||||
try:
|
||||
pending_path = '$PENDING_FILE'
|
||||
if os.path.exists(pending_path):
|
||||
with open(pending_path, 'r', encoding='utf-8') as f:
|
||||
d = json.load(f)
|
||||
d['restart_count'] = 0 # 重置重启计数
|
||||
with open(pending_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(d, f)
|
||||
print('[S99] 已重置重启计数为 0')
|
||||
except Exception as e:
|
||||
print(f'[S99] 重置重启计数失败: {e}')
|
||||
" 2>/dev/null || echo "[S99] 无法重置重启计数(可能需要Python支持)"
|
||||
|
||||
echo "[S99] 已恢复 main.py,重启系统..."
|
||||
echo "[S99] 注意:pending 文件将在心跳成功后由 main.py 删除"
|
||||
sleep 2
|
||||
reboot
|
||||
exit 0
|
||||
fi
|
||||
fi
|
||||
|
||||
# 不启动应用,让系统自动启动机制处理
|
||||
# 这个脚本只负责处理致命错误恢复
|
||||
exit 0
|
||||
|
||||
11
adc.py
11
adc.py
@@ -4,14 +4,13 @@ from maix import time
|
||||
a = adc.ADC(0, adc.RES_BIT_12)
|
||||
|
||||
while True:
|
||||
raw_data = a.read()
|
||||
print(f"ADC raw data:{raw_data}")
|
||||
# raw_data = a.read()
|
||||
# print(f"ADC raw data:{raw_data}")
|
||||
# if raw_data > 2450:
|
||||
# print(f"ADC raw data:{raw_data}")
|
||||
# elif raw_data < 2000:
|
||||
# print(f"ADC raw data:{raw_data}")
|
||||
# time.sleep_ms(50)
|
||||
time.sleep_ms(1)
|
||||
|
||||
# vol = a.read_vol()
|
||||
|
||||
# print(f"ADC vol:{vol}")
|
||||
vol = int(a.read_vol() * 10) / 10
|
||||
print(f"ADC vol:{vol:.1f}, {time.time():.4f}")
|
||||
|
||||
32
app.yaml
32
app.yaml
@@ -1,8 +1,38 @@
|
||||
id: t11
|
||||
name: t11
|
||||
version: 1.0.2
|
||||
version: 1.2.12
|
||||
author: t11
|
||||
icon: ''
|
||||
desc: t11
|
||||
files:
|
||||
- 4g_download_manager.py
|
||||
- 4g_upload_manager.py
|
||||
- app.yaml
|
||||
- archery_netcore.cpython-311-riscv64-linux-gnu.so
|
||||
- at_client.py
|
||||
- camera_manager.py
|
||||
- cameraParameters.xml
|
||||
- config.py
|
||||
- hardware.py
|
||||
- laser_manager.py
|
||||
- logger_manager.py
|
||||
- main.py
|
||||
- model_270139.cvimodel
|
||||
- model_270139.mud
|
||||
- model_270820.cvimodel
|
||||
- model_270820.mud
|
||||
- network.py
|
||||
- ota_manager.py
|
||||
- power.py
|
||||
- server.pem
|
||||
- shoot_manager.py
|
||||
- shot_id_generator.py
|
||||
- target_roi_yolo.py
|
||||
- time_sync.py
|
||||
- triangle_positions.json
|
||||
- triangle_target.py
|
||||
- version.py
|
||||
- vision.py
|
||||
- wifi_config_httpd.py
|
||||
- wifi.py
|
||||
- wpa_supplicant_conf.py
|
||||
|
||||
BIN
archery_netcore.cpython-311-riscv64-linux-gnu.so
Normal file
BIN
archery_netcore.cpython-311-riscv64-linux-gnu.so
Normal file
Binary file not shown.
420
aruco_detector.py
Normal file
420
aruco_detector.py
Normal file
@@ -0,0 +1,420 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
ArUco标记检测模块
|
||||
提供基于ArUco标记的靶心标定和激光点定位功能
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
import config
|
||||
from logger_manager import logger_manager
|
||||
|
||||
|
||||
class ArUcoDetector:
|
||||
"""ArUco标记检测器"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logger_manager.logger
|
||||
# 创建ArUco字典和检测器参数
|
||||
self.aruco_dict = cv2.aruco.getPredefinedDictionary(config.ARUCO_DICT_TYPE)
|
||||
self.detector_params = cv2.aruco.DetectorParameters()
|
||||
|
||||
# 设置检测参数
|
||||
self.detector_params.minMarkerPerimeterRate = config.ARUCO_MIN_MARKER_PERIMETER_RATE
|
||||
self.detector_params.cornerRefinementMethod = config.ARUCO_CORNER_REFINEMENT_METHOD
|
||||
|
||||
# 创建检测器
|
||||
self.detector = cv2.aruco.ArucoDetector(self.aruco_dict, self.detector_params)
|
||||
|
||||
# 预定义靶纸上的标记位置(物理坐标,毫米)
|
||||
self.marker_positions_mm = config.ARUCO_MARKER_POSITIONS_MM
|
||||
self.marker_ids = config.ARUCO_MARKER_IDS
|
||||
self.marker_size_mm = config.ARUCO_MARKER_SIZE_MM
|
||||
self.target_paper_size_mm = config.TARGET_PAPER_SIZE_MM
|
||||
|
||||
# 靶心偏移(相对于靶纸中心)
|
||||
self.target_center_offset_mm = config.TARGET_CENTER_OFFSET_MM
|
||||
|
||||
if self.logger:
|
||||
self.logger.info(f"[ARUCO] ArUco检测器初始化完成,字典类型: {config.ARUCO_DICT_TYPE}")
|
||||
|
||||
def detect_markers(self, frame):
|
||||
"""
|
||||
检测图像中的ArUco标记
|
||||
|
||||
Args:
|
||||
frame: MaixPy图像帧对象
|
||||
|
||||
Returns:
|
||||
(corners, ids, rejected) - 检测到的标记角点、ID列表、被拒绝的候选
|
||||
如果检测失败返回 (None, None, None)
|
||||
"""
|
||||
try:
|
||||
# 转换为OpenCV格式
|
||||
from maix import image
|
||||
img_cv = image.image2cv(frame, False, False)
|
||||
|
||||
# 转换为灰度图(ArUco检测需要)
|
||||
if len(img_cv.shape) == 3:
|
||||
gray = cv2.cvtColor(img_cv, cv2.COLOR_RGB2GRAY)
|
||||
else:
|
||||
gray = img_cv
|
||||
|
||||
# 检测标记
|
||||
corners, ids, rejected = self.detector.detectMarkers(gray)
|
||||
|
||||
if self.logger and ids is not None:
|
||||
self.logger.debug(f"[ARUCO] 检测到 {len(ids)} 个标记: {ids.flatten().tolist()}")
|
||||
|
||||
return corners, ids, rejected
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"[ARUCO] 标记检测失败: {e}")
|
||||
return None, None, None
|
||||
|
||||
def get_target_center_from_markers(self, corners, ids):
|
||||
"""
|
||||
从检测到的ArUco标记计算靶心位置
|
||||
|
||||
Args:
|
||||
corners: 标记角点列表
|
||||
ids: 标记ID列表
|
||||
|
||||
Returns:
|
||||
(center_x, center_y, radius, ellipse_params) 或 (None, None, None, None)
|
||||
center_x, center_y: 靶心像素坐标
|
||||
radius: 估计的靶心半径(像素)
|
||||
ellipse_params: 椭圆参数用于透视校正
|
||||
"""
|
||||
if ids is None or len(ids) < 3:
|
||||
if self.logger:
|
||||
self.logger.debug(f"[ARUCO] 检测到的标记数量不足: {len(ids) if ids is not None else 0} < 3")
|
||||
return None, None, None, None
|
||||
|
||||
try:
|
||||
# 将ID转换为列表便于查找
|
||||
detected_ids = ids.flatten().tolist()
|
||||
|
||||
# 收集检测到的标记中心点和对应的物理坐标
|
||||
image_points = [] # 图像坐标 (像素)
|
||||
object_points = [] # 物理坐标 (毫米)
|
||||
marker_centers = {} # 存储每个标记的中心
|
||||
|
||||
for i, marker_id in enumerate(detected_ids):
|
||||
if marker_id not in self.marker_ids:
|
||||
continue
|
||||
|
||||
# 计算标记中心(四个角的平均值)
|
||||
corner = corners[i][0] # shape: (4, 2)
|
||||
center_x = np.mean(corner[:, 0])
|
||||
center_y = np.mean(corner[:, 1])
|
||||
marker_centers[marker_id] = (center_x, center_y)
|
||||
|
||||
# 添加到点列表
|
||||
image_points.append([center_x, center_y])
|
||||
object_points.append(self.marker_positions_mm[marker_id])
|
||||
|
||||
if len(image_points) < 3:
|
||||
if self.logger:
|
||||
self.logger.debug(f"[ARUCO] 有效标记数量不足: {len(image_points)} < 3")
|
||||
return None, None, None, None
|
||||
|
||||
# 转换为numpy数组
|
||||
image_points = np.array(image_points, dtype=np.float32)
|
||||
object_points = np.array(object_points, dtype=np.float32)
|
||||
|
||||
# 计算单应性矩阵(Homography)
|
||||
# 这建立了物理坐标到图像坐标的映射
|
||||
H, status = cv2.findHomography(object_points, image_points, cv2.RANSAC, 5.0)
|
||||
|
||||
if H is None:
|
||||
if self.logger:
|
||||
self.logger.warning("[ARUCO] 无法计算单应性矩阵")
|
||||
return None, None, None, None
|
||||
|
||||
# 计算靶心在图像中的位置
|
||||
# 靶心物理坐标 = 靶纸中心 + 偏移
|
||||
target_center_mm = np.array([[self.target_center_offset_mm[0],
|
||||
self.target_center_offset_mm[1]]], dtype=np.float32)
|
||||
target_center_mm = target_center_mm.reshape(-1, 1, 2)
|
||||
|
||||
# 使用单应性矩阵投影到图像坐标
|
||||
target_center_img = cv2.perspectiveTransform(target_center_mm, H)
|
||||
center_x = target_center_img[0][0][0]
|
||||
center_y = target_center_img[0][0][1]
|
||||
|
||||
# 计算靶心半径(像素)
|
||||
# 使用已知物理距离和像素距离的比例
|
||||
# 选择两个标记计算比例尺
|
||||
if len(marker_centers) >= 2:
|
||||
# 使用对角线上的标记计算比例尺
|
||||
if 0 in marker_centers and 2 in marker_centers:
|
||||
p1_img = np.array(marker_centers[0])
|
||||
p2_img = np.array(marker_centers[2])
|
||||
p1_mm = np.array(self.marker_positions_mm[0])
|
||||
p2_mm = np.array(self.marker_positions_mm[2])
|
||||
elif 1 in marker_centers and 3 in marker_centers:
|
||||
p1_img = np.array(marker_centers[1])
|
||||
p2_img = np.array(marker_centers[3])
|
||||
p1_mm = np.array(self.marker_positions_mm[1])
|
||||
p2_mm = np.array(self.marker_positions_mm[3])
|
||||
else:
|
||||
# 使用任意两个标记
|
||||
keys = list(marker_centers.keys())
|
||||
p1_img = np.array(marker_centers[keys[0]])
|
||||
p2_img = np.array(marker_centers[keys[1]])
|
||||
p1_mm = np.array(self.marker_positions_mm[keys[0]])
|
||||
p2_mm = np.array(self.marker_positions_mm[keys[1]])
|
||||
|
||||
pixel_distance = np.linalg.norm(p1_img - p2_img)
|
||||
mm_distance = np.linalg.norm(p1_mm - p2_mm)
|
||||
|
||||
if mm_distance > 0:
|
||||
pixels_per_mm = pixel_distance / mm_distance
|
||||
# 标准靶心半径:10环半径约1.22cm = 12.2mm
|
||||
# 但这里我们返回一个估计值,实际环数计算在laser_manager中
|
||||
radius_mm = 122.0 # 整个靶纸的半径约200mm,但靶心区域较小
|
||||
radius = int(radius_mm * pixels_per_mm)
|
||||
else:
|
||||
radius = 100 # 默认值
|
||||
else:
|
||||
radius = 100 # 默认值
|
||||
|
||||
# 计算椭圆参数(用于透视校正)
|
||||
# 从单应性矩阵可以推导出透视变形
|
||||
ellipse_params = self._compute_ellipse_params(H, center_x, center_y)
|
||||
|
||||
if self.logger:
|
||||
self.logger.info(f"[ARUCO] 靶心计算成功: 中心=({center_x:.1f}, {center_y:.1f}), "
|
||||
f"半径={radius}px, 检测到{len(marker_centers)}个标记")
|
||||
|
||||
return (int(center_x), int(center_y)), radius, "aruco", ellipse_params
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"[ARUCO] 计算靶心失败: {e}")
|
||||
import traceback
|
||||
self.logger.error(traceback.format_exc())
|
||||
return None, None, None, None
|
||||
|
||||
def _compute_ellipse_params(self, H, center_x, center_y):
|
||||
"""
|
||||
从单应性矩阵计算椭圆参数,用于透视校正
|
||||
|
||||
Args:
|
||||
H: 单应性矩阵 (3x3)
|
||||
center_x, center_y: 靶心图像坐标
|
||||
|
||||
Returns:
|
||||
ellipse_params: ((center_x, center_y), (width, height), angle)
|
||||
"""
|
||||
try:
|
||||
# 在物理坐标系中画一个圆,投影到图像中看变成什么形状
|
||||
# 物理圆:半径10mm
|
||||
r_mm = 10.0
|
||||
angles = np.linspace(0, 2*np.pi, 16)
|
||||
circle_mm = np.array([[self.target_center_offset_mm[0] + r_mm * np.cos(a),
|
||||
self.target_center_offset_mm[1] + r_mm * np.sin(a)]
|
||||
for a in angles], dtype=np.float32)
|
||||
circle_mm = circle_mm.reshape(-1, 1, 2)
|
||||
|
||||
# 投影到图像
|
||||
circle_img = cv2.perspectiveTransform(circle_mm, H)
|
||||
circle_img = circle_img.reshape(-1, 2)
|
||||
|
||||
# 拟合椭圆
|
||||
if len(circle_img) >= 5:
|
||||
ellipse = cv2.fitEllipse(circle_img.astype(np.float32))
|
||||
return ellipse
|
||||
else:
|
||||
# 从单应性矩阵近似估计
|
||||
# 提取缩放和旋转
|
||||
# H = K * [R|t] 的近似
|
||||
# 这里简化处理:假设没有严重变形
|
||||
scale_x = np.linalg.norm(H[0, :2])
|
||||
scale_y = np.linalg.norm(H[1, :2])
|
||||
avg_scale = (scale_x + scale_y) / 2
|
||||
|
||||
width = r_mm * 2 * scale_x
|
||||
height = r_mm * 2 * scale_y
|
||||
angle = np.degrees(np.arctan2(H[1, 0], H[0, 0]))
|
||||
|
||||
return ((center_x, center_y), (width, height), angle)
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.debug(f"[ARUCO] 计算椭圆参数失败: {e}")
|
||||
return None
|
||||
|
||||
def transform_laser_point(self, laser_point, corners, ids):
|
||||
"""
|
||||
将激光点从图像坐标转换到物理坐标(毫米),再计算相对于靶心的偏移
|
||||
|
||||
Args:
|
||||
laser_point: (x, y) 激光点在图像中的坐标
|
||||
corners: 检测到的标记角点
|
||||
ids: 检测到的标记ID
|
||||
|
||||
Returns:
|
||||
(dx_mm, dy_mm) 激光点相对于靶心的偏移(毫米),或 (None, None)
|
||||
"""
|
||||
if laser_point is None or ids is None or len(ids) < 3:
|
||||
return None, None
|
||||
|
||||
try:
|
||||
# 重新计算单应性矩阵(可以优化为缓存)
|
||||
detected_ids = ids.flatten().tolist()
|
||||
image_points = []
|
||||
object_points = []
|
||||
|
||||
for i, marker_id in enumerate(detected_ids):
|
||||
if marker_id not in self.marker_ids:
|
||||
continue
|
||||
corner = corners[i][0]
|
||||
center_x = np.mean(corner[:, 0])
|
||||
center_y = np.mean(corner[:, 1])
|
||||
image_points.append([center_x, center_y])
|
||||
object_points.append(self.marker_positions_mm[marker_id])
|
||||
|
||||
if len(image_points) < 3:
|
||||
return None, None
|
||||
|
||||
image_points = np.array(image_points, dtype=np.float32)
|
||||
object_points = np.array(object_points, dtype=np.float32)
|
||||
|
||||
H, _ = cv2.findHomography(object_points, image_points, cv2.RANSAC, 5.0)
|
||||
if H is None:
|
||||
return None, None
|
||||
|
||||
# 求逆矩阵,将图像坐标转换到物理坐标
|
||||
H_inv = np.linalg.inv(H)
|
||||
|
||||
laser_img = np.array([[laser_point[0], laser_point[1]]], dtype=np.float32)
|
||||
laser_img = laser_img.reshape(-1, 1, 2)
|
||||
|
||||
laser_mm = cv2.perspectiveTransform(laser_img, H_inv)
|
||||
laser_x_mm = laser_mm[0][0][0]
|
||||
laser_y_mm = laser_mm[0][0][1]
|
||||
|
||||
# 计算相对于靶心的偏移
|
||||
# 注意:Y轴方向可能需要翻转(图像Y向下,物理Y通常向上)
|
||||
dx_mm = laser_x_mm - self.target_center_offset_mm[0]
|
||||
dy_mm = -(laser_y_mm - self.target_center_offset_mm[1]) # 翻转Y轴
|
||||
|
||||
if self.logger:
|
||||
self.logger.debug(f"[ARUCO] 激光点转换: 图像({laser_point[0]:.1f}, {laser_point[1]:.1f}) -> "
|
||||
f"物理({laser_x_mm:.1f}, {laser_y_mm:.1f}) -> "
|
||||
f"偏移({dx_mm:.1f}, {dy_mm:.1f})mm")
|
||||
|
||||
return dx_mm, dy_mm
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"[ARUCO] 激光点转换失败: {e}")
|
||||
return None, None
|
||||
|
||||
def draw_debug_info(self, frame, corners, ids, target_center=None, laser_point=None):
|
||||
"""
|
||||
在图像上绘制调试信息
|
||||
|
||||
Args:
|
||||
frame: MaixPy图像帧
|
||||
corners: 标记角点
|
||||
ids: 标记ID
|
||||
target_center: 计算的靶心位置
|
||||
laser_point: 激光点位置
|
||||
|
||||
Returns:
|
||||
绘制后的图像
|
||||
"""
|
||||
try:
|
||||
from maix import image
|
||||
img_cv = image.image2cv(frame, False, False).copy()
|
||||
|
||||
# 绘制检测到的标记
|
||||
if ids is not None:
|
||||
cv2.aruco.drawDetectedMarkers(img_cv, corners, ids)
|
||||
|
||||
# 绘制标记ID和中心
|
||||
for i, marker_id in enumerate(ids.flatten()):
|
||||
corner = corners[i][0]
|
||||
center_x = int(np.mean(corner[:, 0]))
|
||||
center_y = int(np.mean(corner[:, 1]))
|
||||
|
||||
# 绘制中心点
|
||||
cv2.circle(img_cv, (center_x, center_y), 5, (0, 255, 0), -1)
|
||||
|
||||
# 绘制ID
|
||||
cv2.putText(img_cv, f"ID:{marker_id}",
|
||||
(center_x + 10, center_y - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
||||
|
||||
# 绘制靶心
|
||||
if target_center:
|
||||
cv2.circle(img_cv, target_center, 8, (255, 0, 0), -1)
|
||||
cv2.circle(img_cv, target_center, 50, (255, 0, 0), 2)
|
||||
cv2.putText(img_cv, "TARGET", (target_center[0] + 15, target_center[1] - 15),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
|
||||
|
||||
# 绘制激光点
|
||||
if laser_point:
|
||||
cv2.circle(img_cv, (int(laser_point[0]), int(laser_point[1])), 6, (0, 0, 255), -1)
|
||||
cv2.putText(img_cv, "LASER", (int(laser_point[0]) + 10, int(laser_point[1]) - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
|
||||
|
||||
# 转换回MaixPy图像
|
||||
return image.cv2image(img_cv, False, False)
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"[ARUCO] 绘制调试信息失败: {e}")
|
||||
return frame
|
||||
|
||||
|
||||
# 创建全局单例实例
|
||||
aruco_detector = ArUcoDetector()
|
||||
|
||||
|
||||
def detect_target_with_aruco(frame, laser_point=None):
|
||||
"""
|
||||
使用ArUco标记检测靶心的便捷函数
|
||||
|
||||
Args:
|
||||
frame: MaixPy图像帧
|
||||
laser_point: 激光点坐标(可选)
|
||||
|
||||
Returns:
|
||||
(result_img, center, radius, method, best_radius1, ellipse_params)
|
||||
与detect_circle_v3保持相同的返回格式
|
||||
"""
|
||||
detector = aruco_detector
|
||||
|
||||
# 检测ArUco标记
|
||||
corners, ids, rejected = detector.detect_markers(frame)
|
||||
|
||||
# 计算靶心
|
||||
center, radius, method, ellipse_params = detector.get_target_center_from_markers(corners, ids)
|
||||
|
||||
# 绘制调试信息
|
||||
result_img = detector.draw_debug_info(frame, corners, ids, center, laser_point)
|
||||
|
||||
# 返回与detect_circle_v3相同的格式
|
||||
# best_radius1用于距离估算,这里用radius代替
|
||||
return result_img, center, radius, method, radius, ellipse_params
|
||||
|
||||
|
||||
def compute_laser_offset_aruco(laser_point, corners, ids):
|
||||
"""
|
||||
使用ArUco计算激光点相对于靶心的偏移(毫米)
|
||||
|
||||
Args:
|
||||
laser_point: (x, y) 激光点图像坐标
|
||||
corners: ArUco标记角点
|
||||
ids: ArUco标记ID
|
||||
|
||||
Returns:
|
||||
(dx_mm, dy_mm) 偏移量(毫米),或 (None, None)
|
||||
"""
|
||||
return aruco_detector.transform_laser_point(laser_point, corners, ids)
|
||||
307
at_client.py
Normal file
307
at_client.py
Normal file
@@ -0,0 +1,307 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
AT客户端模块
|
||||
负责4G模块的AT命令通信和URC解析
|
||||
"""
|
||||
import _thread
|
||||
from maix import time
|
||||
import re
|
||||
import threading
|
||||
|
||||
class ATClient:
|
||||
"""
|
||||
单读者 AT/URC 客户端:唯一读取 uart4g,避免 tcp_main/at()/OTA 抢读导致 EOF / 丢包。
|
||||
- send(cmd, expect, timeout_ms) : 发送 AT 并等待 expect
|
||||
- pop_tcp_payload() : 获取 +MIPURC:"rtcp" 的 payload(已按长度裁剪)
|
||||
- pop_http_event() : 获取 +MHTTPURC 事件(header/content)
|
||||
"""
|
||||
def __init__(self, uart_obj):
|
||||
self.uart = uart_obj
|
||||
self._cmd_lock = threading.Lock()
|
||||
self._q_lock = threading.Lock()
|
||||
self._rx = b""
|
||||
self._tcp_payloads = []
|
||||
self._http_events = []
|
||||
|
||||
# 当前命令等待状态(仅允许单命令 in-flight)
|
||||
self._waiting = False
|
||||
self._expect = b"OK"
|
||||
self._resp = b""
|
||||
|
||||
self._running = False
|
||||
|
||||
def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
_thread.start_new_thread(self._reader_loop, ())
|
||||
|
||||
def stop(self):
|
||||
self._running = False
|
||||
|
||||
def flush(self):
|
||||
"""清空内部缓存与队列(用于 OTA/异常恢复)"""
|
||||
with self._q_lock:
|
||||
self._rx = b""
|
||||
self._tcp_payloads.clear()
|
||||
self._http_events.clear()
|
||||
self._resp = b""
|
||||
|
||||
def pop_tcp_payload(self):
|
||||
with self._q_lock:
|
||||
if self._tcp_payloads:
|
||||
return self._tcp_payloads.pop(0)
|
||||
return None
|
||||
|
||||
def pop_http_event(self):
|
||||
with self._q_lock:
|
||||
if self._http_events:
|
||||
return self._http_events.pop(0)
|
||||
return None
|
||||
|
||||
def _push_tcp_payload(self, payload: bytes):
|
||||
# 注意:在 _reader_loop 内部解析 URC 时已经持有 _q_lock,
|
||||
# 这里不要再次 acquire(锁不可重入,会死锁)。
|
||||
self._tcp_payloads.append(payload)
|
||||
|
||||
def _push_http_event(self, ev):
|
||||
# 同上:避免在 _reader_loop 持锁期间二次 acquire
|
||||
self._http_events.append(ev)
|
||||
|
||||
def send(self, cmd: str, expect: str = "OK", timeout_ms: int = 2000):
|
||||
"""
|
||||
发送 AT 命令并等待 expect(子串匹配)。
|
||||
注意:expect=">" 用于等待 prompt。
|
||||
"""
|
||||
expect_b = expect.encode() if isinstance(expect, str) else expect
|
||||
with self._cmd_lock:
|
||||
# 初始化等待
|
||||
self._waiting = True
|
||||
self._expect = expect_b
|
||||
self._resp = b""
|
||||
|
||||
# 发送
|
||||
if cmd:
|
||||
# 注意:这里不要再用 uart4g_lock(否则外层已经持锁时会死锁)。
|
||||
# 写入由 _cmd_lock 串行化即可。
|
||||
self.uart.write((cmd + "\r\n").encode())
|
||||
|
||||
t0 = time.ticks_ms()
|
||||
while abs(time.ticks_diff(time.ticks_ms(), t0)) < timeout_ms:
|
||||
if (not self._waiting) or (self._expect in self._resp):
|
||||
self._waiting = False
|
||||
break
|
||||
time.sleep_ms(5)
|
||||
|
||||
# 超时也返回已收集内容(便于诊断)
|
||||
self._waiting = False
|
||||
try:
|
||||
return self._resp.decode(errors="ignore")
|
||||
except:
|
||||
return str(self._resp)
|
||||
|
||||
def _find_urc_tag(self, tag: bytes):
|
||||
"""
|
||||
只在"真正的 URC 边界"查找 tag,避免误命中 HTTP payload 内容。
|
||||
规则:tag 必须出现在 buffer 开头,或紧跟在 b"\\r\\n" 后面。
|
||||
"""
|
||||
try:
|
||||
i = 0
|
||||
rx = self._rx
|
||||
while True:
|
||||
j = rx.find(tag, i)
|
||||
if j < 0:
|
||||
return -1
|
||||
if j == 0:
|
||||
return 0
|
||||
if j >= 2 and rx[j - 2:j] == b"\r\n":
|
||||
return j
|
||||
i = j + 1
|
||||
except:
|
||||
return -1
|
||||
|
||||
def _parse_mipurc_rtcp(self):
|
||||
"""
|
||||
解析:+MIPURC: "rtcp",<link_id>,<len>,<payload...>
|
||||
之前硬编码 link_id=0 会导致在多连接/重连场景下收不到数据。
|
||||
"""
|
||||
prefix = b'+MIPURC: "rtcp",'
|
||||
i = self._find_urc_tag(prefix)
|
||||
if i < 0:
|
||||
return False
|
||||
# 丢掉前置噪声
|
||||
if i > 0:
|
||||
self._rx = self._rx[i:]
|
||||
i = 0
|
||||
|
||||
j = len(prefix)
|
||||
# 解析 link_id
|
||||
k = j
|
||||
while k < len(self._rx) and 48 <= self._rx[k] <= 57:
|
||||
k += 1
|
||||
if k == j or k >= len(self._rx):
|
||||
return False
|
||||
if self._rx[k:k+1] != b",":
|
||||
self._rx = self._rx[1:]
|
||||
return True
|
||||
try:
|
||||
link_id = int(self._rx[j:k].decode())
|
||||
except:
|
||||
self._rx = self._rx[1:]
|
||||
return True
|
||||
|
||||
# 解析 len
|
||||
j2 = k + 1
|
||||
k2 = j2
|
||||
while k2 < len(self._rx) and 48 <= self._rx[k2] <= 57:
|
||||
k2 += 1
|
||||
if k2 == j2 or k2 >= len(self._rx):
|
||||
return False
|
||||
if self._rx[k2:k2+1] != b",":
|
||||
self._rx = self._rx[1:]
|
||||
return True
|
||||
try:
|
||||
n = int(self._rx[j2:k2].decode())
|
||||
except:
|
||||
self._rx = self._rx[1:]
|
||||
return True
|
||||
|
||||
payload_start = k2 + 1
|
||||
payload_end = payload_start + n
|
||||
if len(self._rx) < payload_end:
|
||||
return False # payload 未收齐
|
||||
|
||||
payload = self._rx[payload_start:payload_end]
|
||||
# 把 link_id 一起带上,便于上层过滤(如果需要)
|
||||
self._push_tcp_payload((link_id, payload))
|
||||
self._rx = self._rx[payload_end:]
|
||||
return True
|
||||
|
||||
def _parse_mhttpurc_header(self):
|
||||
tag = b'+MHTTPURC: "header",'
|
||||
i = self._find_urc_tag(tag)
|
||||
if i < 0:
|
||||
return False
|
||||
if i > 0:
|
||||
self._rx = self._rx[i:]
|
||||
i = 0
|
||||
|
||||
# header: +MHTTPURC: "header",<id>,<code>,<hdr_len>,<hdr_text...>
|
||||
j = len(tag)
|
||||
comma_count = 0
|
||||
k = j
|
||||
while k < len(self._rx) and comma_count < 3:
|
||||
if self._rx[k:k+1] == b",":
|
||||
comma_count += 1
|
||||
k += 1
|
||||
if comma_count < 3:
|
||||
return False
|
||||
|
||||
prefix = self._rx[:k]
|
||||
m = re.search(rb'\+MHTTPURC: "header",\s*(\d+),\s*(\d+),\s*(\d+),', prefix)
|
||||
if not m:
|
||||
self._rx = self._rx[1:]
|
||||
return True
|
||||
urc_id = int(m.group(1))
|
||||
code = int(m.group(2))
|
||||
hdr_len = int(m.group(3))
|
||||
|
||||
text_start = k
|
||||
text_end = text_start + hdr_len
|
||||
if len(self._rx) < text_end:
|
||||
return False
|
||||
|
||||
hdr_text = self._rx[text_start:text_end].decode("utf-8", "ignore")
|
||||
self._push_http_event(("header", urc_id, code, hdr_text))
|
||||
self._rx = self._rx[text_end:]
|
||||
return True
|
||||
|
||||
def _parse_mhttpurc_content(self):
|
||||
tag = b'+MHTTPURC: "content",'
|
||||
i = self._find_urc_tag(tag)
|
||||
if i < 0:
|
||||
return False
|
||||
if i > 0:
|
||||
self._rx = self._rx[i:]
|
||||
i = 0
|
||||
|
||||
# content: +MHTTPURC: "content",<id>,<total>,<sum>,<cur>,<payload...>
|
||||
j = len(tag)
|
||||
comma_count = 0
|
||||
k = j
|
||||
while k < len(self._rx) and comma_count < 4:
|
||||
if self._rx[k:k+1] == b",":
|
||||
comma_count += 1
|
||||
k += 1
|
||||
if comma_count < 4:
|
||||
return False
|
||||
|
||||
prefix = self._rx[:k]
|
||||
m = re.search(rb'\+MHTTPURC: "content",\s*(\d+),\s*(\d+),\s*(\d+),\s*(\d+),', prefix)
|
||||
if not m:
|
||||
self._rx = self._rx[1:]
|
||||
return True
|
||||
urc_id = int(m.group(1))
|
||||
total_len = int(m.group(2))
|
||||
sum_len = int(m.group(3))
|
||||
cur_len = int(m.group(4))
|
||||
|
||||
payload_start = k
|
||||
payload_end = payload_start + cur_len
|
||||
if len(self._rx) < payload_end:
|
||||
return False
|
||||
|
||||
payload = self._rx[payload_start:payload_end]
|
||||
self._push_http_event(("content", urc_id, total_len, sum_len, cur_len, payload))
|
||||
self._rx = self._rx[payload_end:]
|
||||
return True
|
||||
|
||||
def _reader_loop(self):
|
||||
while self._running:
|
||||
# 关键:UART 驱动偶发 read failed,必须兜住,否则线程挂了 OTA/TCP 都会卡死
|
||||
try:
|
||||
d = self.uart.read(4096) # 8192 在一些驱动上更容易触发 read failed
|
||||
except Exception as e:
|
||||
try:
|
||||
print("[ATClient] uart read failed:", e)
|
||||
except:
|
||||
pass
|
||||
time.sleep_ms(50)
|
||||
continue
|
||||
|
||||
if not d:
|
||||
time.sleep_ms(1)
|
||||
continue
|
||||
|
||||
with self._q_lock:
|
||||
self._rx += d
|
||||
if self._waiting:
|
||||
self._resp += d
|
||||
|
||||
while True:
|
||||
progressed = (
|
||||
self._parse_mipurc_rtcp()
|
||||
or self._parse_mhttpurc_header()
|
||||
or self._parse_mhttpurc_content()
|
||||
)
|
||||
if not progressed:
|
||||
break
|
||||
|
||||
# 使用 ota_manager 访问 ota_in_progress
|
||||
try:
|
||||
from ota_manager import ota_manager
|
||||
ota_flag = ota_manager.ota_in_progress
|
||||
except:
|
||||
ota_flag = False
|
||||
|
||||
has_http_hint = (b"+MHTTP" in self._rx) or (b"+MHTTPURC" in self._rx)
|
||||
if ota_flag or has_http_hint:
|
||||
if len(self._rx) > 512 * 1024:
|
||||
self._rx = self._rx[-256 * 1024:]
|
||||
else:
|
||||
if len(self._rx) > 16384:
|
||||
self._rx = self._rx[-4096:]
|
||||
|
||||
|
||||
|
||||
33
cameraParameters.xml
Normal file
33
cameraParameters.xml
Normal file
@@ -0,0 +1,33 @@
|
||||
<?xml version="1.0"?>
|
||||
<opencv_storage>
|
||||
<calibrationDate>"Sat Apr 11 12:05:27 2026"</calibrationDate>
|
||||
<framesCount>29</framesCount>
|
||||
<cameraResolution>
|
||||
640 480</cameraResolution>
|
||||
<camera_matrix type_id="opencv-matrix">
|
||||
<rows>3</rows>
|
||||
<cols>3</cols>
|
||||
<dt>d</dt>
|
||||
<data>
|
||||
2207.9058323074869 0. 328.90661220953149 0. 2207.9058323074869
|
||||
205.49515894111076 0. 0. 1.</data></camera_matrix>
|
||||
<camera_matrix_std_dev type_id="opencv-matrix">
|
||||
<rows>4</rows>
|
||||
<cols>1</cols>
|
||||
<dt>d</dt>
|
||||
<data>
|
||||
0. 11.687428265309892 3.6908895632668468 3.597571733110271</data></camera_matrix_std_dev>
|
||||
<distortion_coefficients type_id="opencv-matrix">
|
||||
<rows>1</rows>
|
||||
<cols>5</cols>
|
||||
<dt>d</dt>
|
||||
<data>
|
||||
-0.63036604771649651 3.3832710000807449 0. 0. -0.45113389267675552</data></distortion_coefficients>
|
||||
<distortion_coefficients_std_dev type_id="opencv-matrix">
|
||||
<rows>5</rows>
|
||||
<cols>1</cols>
|
||||
<dt>d</dt>
|
||||
<data>
|
||||
0.025002349846111244 1.0651877135605927 0. 0. 0.04021252864120229</data></distortion_coefficients_std_dev>
|
||||
<avg_reprojection_error>0.28992233810828955</avg_reprojection_error>
|
||||
</opencv_storage>
|
||||
137
camera_manager.py
Normal file
137
camera_manager.py
Normal file
@@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
相机管理器模块
|
||||
提供相机和显示的统一管理和线程安全访问
|
||||
"""
|
||||
import threading
|
||||
import config
|
||||
from logger_manager import logger_manager
|
||||
|
||||
|
||||
class CameraManager:
|
||||
"""相机管理器(单例)"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(CameraManager, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 私有对象
|
||||
self._camera = None
|
||||
self._display = None
|
||||
|
||||
# 线程安全锁
|
||||
self._camera_lock = threading.Lock()
|
||||
self._display_lock = threading.Lock()
|
||||
|
||||
# 相机配置
|
||||
self._camera_width = 640
|
||||
self._camera_height = 480
|
||||
|
||||
self._initialized = True
|
||||
|
||||
# ==================== 初始化方法 ====================
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
"""获取 logger 对象"""
|
||||
return logger_manager.logger
|
||||
|
||||
def init_camera(self, width=640, height=480):
|
||||
"""初始化相机"""
|
||||
if self._camera is not None:
|
||||
return self._camera
|
||||
|
||||
from maix import camera
|
||||
|
||||
self._camera_width = width
|
||||
self._camera_height = height
|
||||
|
||||
with self._camera_lock:
|
||||
if self._camera is None:
|
||||
self._camera = camera.Camera(width, height)
|
||||
|
||||
return self._camera
|
||||
|
||||
def init_display(self):
|
||||
"""初始化显示"""
|
||||
if self._display is not None:
|
||||
return self._display
|
||||
|
||||
from maix import display
|
||||
|
||||
with self._display_lock:
|
||||
if self._display is None:
|
||||
self._display = display.Display()
|
||||
|
||||
return self._display
|
||||
|
||||
# ==================== 访问方法 ====================
|
||||
|
||||
@property
|
||||
def camera(self):
|
||||
"""获取相机实例(懒加载)"""
|
||||
if self._camera is None:
|
||||
self.init_camera()
|
||||
return self._camera
|
||||
|
||||
@property
|
||||
def display(self):
|
||||
"""获取显示实例(懒加载)"""
|
||||
if self._display is None:
|
||||
self.init_display()
|
||||
return self._display
|
||||
|
||||
# ==================== 业务方法 ====================
|
||||
|
||||
def read_frame(self):
|
||||
"""
|
||||
线程安全地读取一帧图像
|
||||
|
||||
Returns:
|
||||
frame: 图像帧对象
|
||||
"""
|
||||
with self._camera_lock:
|
||||
if self._camera is None:
|
||||
self.init_camera()
|
||||
return self._camera.read()
|
||||
|
||||
def show(self, image):
|
||||
"""
|
||||
线程安全地显示图像
|
||||
|
||||
Args:
|
||||
image: 要显示的图像对象
|
||||
"""
|
||||
with self._display_lock:
|
||||
if self._display is None:
|
||||
self.init_display()
|
||||
self._display.show(image)
|
||||
|
||||
def release(self):
|
||||
"""释放相机和显示资源(如果需要)"""
|
||||
with self._camera_lock:
|
||||
if self._camera is not None:
|
||||
# MaixPy 的 Camera 可能不需要显式释放,但可以在这里清理
|
||||
self._camera = None
|
||||
|
||||
with self._display_lock:
|
||||
if self._display is not None:
|
||||
# MaixPy 的 Display 可能不需要显式释放
|
||||
self._display = None
|
||||
|
||||
|
||||
# 创建全局单例实例
|
||||
camera_manager = CameraManager()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
337
config.py
Normal file
337
config.py
Normal file
@@ -0,0 +1,337 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
系统配置常量
|
||||
这些值在程序运行期间基本不变,或只在配置时改变
|
||||
"""
|
||||
from version import VERSION
|
||||
|
||||
# ==================== 应用配置 ====================
|
||||
APP_VERSION = VERSION
|
||||
APP_DIR = "/maixapp/apps/t11"
|
||||
LOCAL_FILENAME = APP_DIR + "/main_tmp.py"
|
||||
|
||||
# ==================== 相机配置 ====================
|
||||
# 相机初始化分辨率(CameraManager / main.py 使用)
|
||||
CAMERA_WIDTH = 640
|
||||
CAMERA_HEIGHT = 480
|
||||
|
||||
# 三角形检测缩图比例:默认按相机最长边缩到 1/2(性能更稳;可按需调整)
|
||||
# 取值范围建议 (0.25 ~ 1.0];1.0 表示不缩图
|
||||
TRIANGLE_DETECT_SCALE = 0.4
|
||||
|
||||
# ==================== 服务器配置 ====================
|
||||
# SERVER_IP = "stcp.shelingxingqiu.com"
|
||||
SERVER_IP = "www.shelingxingqiu.com"
|
||||
SERVER_PORT = 50005
|
||||
HEARTBEAT_INTERVAL = 15 # 心跳间隔(秒)
|
||||
|
||||
# WiFi 质量评估(开机先尝试 WiFi;质量差且 4G 可用则切到 4G,本次上电直至关机锁定 4G)
|
||||
WIFI_QUALITY_RTT_SAMPLES = 3 # 到业务服务器 TCP 建连耗时采样次数,取中位数
|
||||
WIFI_QUALITY_RTT_BAD_MS = 600.0 # 中位数超过此值认为延迟过高
|
||||
WIFI_QUALITY_RTT_WARN_MS = 350.0 # 与 RSSI 联合:超过此值且信号弱也判为差
|
||||
WIFI_QUALITY_RSSI_BAD_DBM = -80.0 # 低于此 dBm(更负更差)视为信号弱
|
||||
WIFI_QUALITY_USE_RSSI = True # 是否把 RSSI 纳入综合判定
|
||||
|
||||
# WiFi 热点配网(手机连设备 AP,浏览器提交路由器 SSID/密码;仅 GET/POST,标准库 socket)
|
||||
WIFI_CONFIG_AP_FALLBACK = True # # WiFi 配网失败时,是否退回热点模式,并等待重新配网
|
||||
WIFI_AP_FALLBACK_WAIT_SEC = 5 # 等待5秒后再检测STA/4G
|
||||
WIFI_CONFIG_AP_TIMEOUT = 5 # 热点模式超时时间(秒)
|
||||
WIFI_CONFIG_AP_ENABLED = True # True=启动时开热点并起迷你 HTTP 配网服务
|
||||
WIFI_CONFIG_AP_SSID = "ArcherySetup" # 设备发出的热点名称
|
||||
WIFI_CONFIG_AP_PASSWORD = "12345678" # 热点密码(WPA2 通常至少 8 位)
|
||||
WIFI_CONFIG_HTTP_HOST = "0.0.0.0" # HTTP 监听地址
|
||||
WIFI_CONFIG_HTTP_PORT = 8080 # 默认 8080,避免占用 80 需 root
|
||||
WIFI_CONFIG_AP_IP = "192.168.66.1" # 与 MaixPy Wifi.start_ap 默认一致,手机访问 http://192.168.66.1:8080/
|
||||
# 这个地址需要和 /boot/wifi.ipv4_prefix 配合,才能正确访问。
|
||||
# 比如说 /boot/wifi.ipv4_prefix 需要写成 192.168.66
|
||||
# ===== TCP over SSL(TLS) 配置 =====
|
||||
USE_TCP_SSL = True # True=按手册走 MSSLCFG/MIPCFG 绑定 SSL
|
||||
TCP_LINK_ID = 2 #
|
||||
TCP_SSL_PORT = 50006 # TLS 端口(不一定必须 443,以服务器为准)
|
||||
|
||||
# SSL profile
|
||||
SSL_ID = 1 # ssl_id=1
|
||||
SSL_AUTH_MODE = 1 # 1=单向认证(验证服务器),2=双向
|
||||
SSL_VERIFY_MODE = 1 # 0=不验(仅测试用);1=写入并使用 CA 证书
|
||||
|
||||
SSL_CERT_FILENAME = "server.pem" # 模组里证书名(MSSLCERTWR / MSSLCFG="cert" 用)
|
||||
SSL_CERT_PATH = APP_DIR + "/server.pem" # 设备文件系统里 CA 证书路径(你自己放进去)
|
||||
# MIPOPEN 末尾的参数在不同固件里含义可能不同;按你手册例子保留
|
||||
MIPOPEN_TAIL = ",,0"
|
||||
|
||||
# ==================== 文件路径配置 ====================
|
||||
CONFIG_FILE = "/root/laser_config.json"
|
||||
LOG_FILE = APP_DIR + "/app.log"
|
||||
BACKUP_BASE = APP_DIR + "/backups"
|
||||
|
||||
# ==================== 硬件配置 ====================
|
||||
# UART配置
|
||||
UART4G_DEVICE = "/dev/ttyS2"
|
||||
UART4G_BAUDRATE = 115200
|
||||
DISTANCE_SERIAL_DEVICE = "/dev/ttyS1"
|
||||
DISTANCE_SERIAL_BAUDRATE = 9600
|
||||
|
||||
# I2C:板载 WiFi 方案固定 I2C5,引脚 A15(SCL) / A27(SDA),供 INA226 等
|
||||
I2C_BUS_NUM = 5
|
||||
|
||||
INA226_ADDR = 0x40
|
||||
# False=完全不访问 INA226(无电源计量板或未供电时避免 ~2.5s writeto 重试与底层 write failed 日志);量产有芯片时设为 True
|
||||
INA226_ENABLE = True
|
||||
# True=整总线 I2C scan 探测 INA226(在部分平台上极慢,可达 ~90s+);False=仅对 INA226_ADDR 快速探测(writeto 空写)
|
||||
INA226_PROBE_FULL_BUS_SCAN = False
|
||||
REG_CONFIGURATION = 0x00
|
||||
REG_BUS_VOLTAGE = 0x02
|
||||
REG_CURRENT = 0x04 # 电流寄存器
|
||||
REG_CALIBRATION = 0x05
|
||||
CALIBRATION_VALUE = 0x1400
|
||||
|
||||
# ==================== 空气传感器配置 ====================
|
||||
ADC_TRIGGER_THRESHOLD = 2700 # TODO:4096只是用于测试,因为最大值是4095,这个值是永远不会触发的,最终需要改为正常值
|
||||
AIR_PRESSURE_lOG = False # TODO: 在正式环境中关闭
|
||||
AIR_PRESSURE_HARDWARE_MAX = 10
|
||||
# ADC配置
|
||||
ADC_CHANNEL = 0
|
||||
ADC_LASER_THRESHOLD = 3000
|
||||
|
||||
# ==================== 激光配置 ====================
|
||||
MODULE_ADDR = 0x00
|
||||
LASER_ON_CMD = bytes([0xAA, MODULE_ADDR, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x01, 0xC1])
|
||||
LASER_OFF_CMD = bytes([0xAA, MODULE_ADDR, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x00, 0xC0])
|
||||
DISTANCE_QUERY_CMD = bytes([0xAA, MODULE_ADDR, 0x00, 0x20, 0x00, 0x01, 0x00, 0x00, 0x21]) # 激光测距查询命令
|
||||
DISTANCE_RESPONSE_LEN = 13 # 激光测距响应数据长度(字节)
|
||||
DEFAULT_LASER_POINT = (320, 245) # 默认激光中心点
|
||||
|
||||
# 硬编码激光点配置
|
||||
HARDCODE_LASER_POINT = True # 是否使用硬编码的激光点(True=使用硬编码值,False=使用校准值)
|
||||
HARDCODE_LASER_POINT_VALUE = (320, 296) # 硬编码的激光点坐标(315, 245) # # 硬编码的激光点坐标 (x, y)
|
||||
|
||||
# 激光点检测配置
|
||||
LASER_DETECTION_THRESHOLD = 140 # 红色通道阈值(默认120,可调整,范围建议:100-150)
|
||||
LASER_RED_RATIO = 1.5 # 红色相对于绿色/蓝色的倍数要求(默认1.5,可调整,范围建议:1.3-2.0)
|
||||
LASER_SEARCH_RADIUS = 50 # 搜索半径(像素),从图像中心开始搜索(默认20,限制激光点不能偏离中心太远)
|
||||
LASER_MAX_DISTANCE_FROM_CENTER = 50 # 激光点距离中心的最大允许距离(像素),超过此距离则拒绝(默认20)
|
||||
LASER_OVEREXPOSED_THRESHOLD = 200 # 过曝红色判断阈值(默认200,接近白色时的阈值)
|
||||
LASER_OVEREXPOSED_DIFF = 10 # 过曝红色时,r 与 g/b 的最小差值(默认10)
|
||||
LASER_REQUIRE_IN_ELLIPSE = False # 是否要求激光点必须在黄心椭圆内(True=必须,False=不要求)
|
||||
LASER_USE_ELLIPSE_FITTING = True # 是否使用椭圆拟合方法查找激光点(True=椭圆拟合更准确,False=最亮点方法)
|
||||
LASER_MIN_AREA = 5 # 激光点区域的最小面积(像素),小于此值认为是噪声(默认5)
|
||||
LASER_DRAW_ELLIPSE = True # 是否在图像上绘制激光点的拟合椭圆(True=绘制,False=不绘制)
|
||||
|
||||
# ==================== 视觉检测配置 ====================
|
||||
FOCAL_LENGTH_PIX = 2250.0 # 焦距(像素)
|
||||
REAL_RADIUS_CM = 20 # 靶心实际半径(厘米)
|
||||
|
||||
# 图像清晰度检测配置
|
||||
IMAGE_SHARPNESS_THRESHOLD = 100.0 # 清晰度阈值,低于此值认为图像模糊
|
||||
# 清晰图像通常 > 200,模糊图像通常 < 100
|
||||
|
||||
# 激光与摄像头物理位置配置
|
||||
LASER_CAMERA_OFFSET_CM = 1.4 # 激光在摄像头下方的物理距离(厘米),正值表示激光在摄像头下方
|
||||
IMAGE_CENTER_X = 320 # 图像中心 X 坐标
|
||||
IMAGE_CENTER_Y = 240 # 图像中心 Y 坐标
|
||||
|
||||
# ==================== 三角形四角标记:单应性偏移 + PnP 估距 ====================
|
||||
# 依赖 cameraParameters.xml(相机内参)与 triangle_positions.json(四角物方坐标,厘米或毫米见 JSON 约定)。
|
||||
# 部署时请把这两个文件放到 APP_DIR(与 main 同应用目录),或改下面路径为设备上的实际绝对路径。
|
||||
USE_TRIANGLE_OFFSET = True # False 时仅走黄心圆/椭圆 + 半径估距,不使用三角形路径
|
||||
CAMERA_CALIB_XML = APP_DIR + "/cameraParameters.xml"
|
||||
TRIANGLE_POSITIONS_JSON = APP_DIR + "/triangle_positions.json"
|
||||
# 检测到的三角形边长在图像中的像素范围,分辨率或靶纸占比变化时可微调
|
||||
TRIANGLE_SIZE_RANGE = (8, 500)
|
||||
# PnP 距离合理性检查(可选):超出范围时认为本次检测有误,回退圆心算法
|
||||
# 设为 0 表示不启用(主要防线是单应矩阵 sx/sy 比值检查,无需提前知道距离)
|
||||
# 如果射箭距离很固定,可设具体范围(如 min=2.5, max=6.0)作为额外保险
|
||||
TRIANGLE_DISTANCE_MIN_M = 0.0 # 0=不启用下限检查
|
||||
TRIANGLE_DISTANCE_MAX_M = 0.0 # 0=不启用上限检查
|
||||
# 三角形检测兜底增强:CLAHE(更鲁棒但更慢)。颜色阈值修复后通常不需要,保持关闭以优先速度。
|
||||
TRIANGLE_ENABLE_CLAHE_FALLBACK = False
|
||||
# 三角形检测调试:保存 Otsu 二值化图像(临时调试用,定位后关闭)
|
||||
TRIANGLE_SAVE_DEBUG_IMAGE = False
|
||||
# 三角形颜色过滤阈值(三角形内部灰度判定)
|
||||
# 如果三角形标记印刷较浅/环境较亮,可放宽:
|
||||
# max_interior_gray: 三角形内部平均灰度上限(越大越宽松,90→130 适应浅色印刷)
|
||||
# dark_pixel_gray: "暗像素"灰度判定阈值(越大越宽松,80→130)
|
||||
# min_dark_ratio: 暗像素占比下限(越小越宽松,0.70→0.30)
|
||||
TRIANGLE_MAX_INTERIOR_GRAY = 130
|
||||
TRIANGLE_DARK_PIXEL_GRAY = 130
|
||||
TRIANGLE_MIN_DARK_RATIO = 0.30
|
||||
# 三角形相对对比度阈值:内部比周围暗多少灰度值才认为有效(0=禁用相对对比度)
|
||||
TRIANGLE_MIN_CONTRAST_DIFF = 15
|
||||
# 三角形形状约束容差(等腰直角判定松紧度)
|
||||
# 增大可容忍轮廓轻微变形(印刷不均、阴影局部切角),减少"差一点点就失败"的漏检
|
||||
# 建议范围:0.20(原始/严格) ~ 0.30(宽松);超过 0.35 容易误检非三角形
|
||||
TRIANGLE_SHAPE_LEG_TOLERANCE = 0.25 # 两直角边长度比例容差(原 0.20)
|
||||
TRIANGLE_SHAPE_HYP_TOLERANCE = 0.25 # 斜边与期望长度比例容差(原 0.20)
|
||||
TRIANGLE_SHAPE_COS_TOLERANCE = 0.25 # 直角余弦绝对值上限(原 0.20,越小越严格)
|
||||
# 三角形检测主超时(毫秒):join 等待子线程的最长时间。
|
||||
# 整段 try_triangle_scoring 含「多路径二值化 + C(n,4) 四角评分 + 单应性 + PnP」,往往比黄心圆检测慢。
|
||||
# 建议设为实测最坏耗时的 1.2 倍;超时后圆心检测仍会并行跑完,跑完后若三角形已结束则优先用三角形。
|
||||
TRIANGLE_TIMEOUT_MS = 1000
|
||||
# True=打印各阶段耗时(ms),用于定位瓶颈;稳定后可 False 减少日志
|
||||
TRIANGLE_TIMING_LOG = True
|
||||
# True=Stage2 每个子框内传统三角失败时打一条统计(Otsu/Adaptive 下轮廓数与各拒绝原因计数)
|
||||
TRIANGLE_LOG_STAGE2_PATCH_REJECT = True
|
||||
|
||||
# 仅检出 3 个真实三角时:是否在预测位置附近做小 ROI(Otsu/adaptive)再搜第 4 个真实三角。
|
||||
# False=跳过该搜索,直接用几何推算的虚拟第 4 点(offset_method=triangle_homography_3pt),省 ~10~120ms;若实测偏移可接受可关。
|
||||
TRIANGLE_FOURTH_ROI_SEARCH_ENABLE = False
|
||||
|
||||
# ── 轻量锐化(Unsharp Mask)──────────────────────────────────────────────────
|
||||
# 目的:轻度/中度模糊时增强边缘,让三角形轮廓更易被 approxPolyDP 检出。
|
||||
# 严重运动模糊时反而会放大噪声,建议搭配 sharpness 检测自动触发(见下)。
|
||||
# YOLO 裁切后图已较清晰时可 False,省去 Unsharp 开销并减轻振铃。
|
||||
TRIANGLE_SHARPEN_ENABLE = False # False=关闭锐化(彻底跳过计算,最省时)
|
||||
# 仅当帧清晰度(Laplacian 方差)低于此值时才锐化;高于此值说明图片本身够清晰,不动
|
||||
# 0=总是锐化;建议 50~150;对应日志中 [TRI] sharpness=xxx
|
||||
TRIANGLE_SHARPEN_THRESHOLD = 0.0 # 0=总是锐化(不做 Laplacian 判断,省去计算)
|
||||
# Unsharp Mask 高斯核 sigma(越大锐化越强,通常 1.0~3.0)
|
||||
TRIANGLE_SHARPEN_SIGMA = 2.0
|
||||
# Unsharp Mask 强度系数(越大锐化越猛,通常 1.2~2.0;>2 易产生振铃)
|
||||
TRIANGLE_SHARPEN_STRENGTH = 1.5
|
||||
|
||||
# 三角形检测用灰度来源(ROI 裁切、缩放到 img_det 之后;与 vision 一致按 RGB 输入)
|
||||
# rgb — 常规 cv2.cvtColor RGB2GRAY
|
||||
# v_suppress — HSV 的 V:亮度 >= TRIANGLE_HSV_V_SUPPRESS_ABOVE 的像素灰度强制为 255,压制黄/红/蓝等亮环后再走原有 Otsu 流水线
|
||||
# fallback_v_suppress — 先用 rgb 跑 detect;若检出三角形 <3,再用 v_suppress 重跑一遍(省平均耗时,坏帧可多救一点)
|
||||
# try_both — rgb 与 v_suppress 各完整跑一遍 detect_triangle_markers,取检出数更多一侧(平局保留 rgb);耗时约 2 倍,用于对比效果
|
||||
TRIANGLE_GRAY_MODE = "v_suppress"
|
||||
TRIANGLE_HSV_V_SUPPRESS_ABOVE = 200 # 0~255;偏高则环残留多,偏低则可能伤到暗三角边缘,建议 180~220 扫一圈
|
||||
|
||||
# 三角形检测性能/鲁棒性参数(偏向速度的默认值)
|
||||
# 说明:
|
||||
# - Otsu 是最快的全局阈值;adaptiveThreshold 更鲁棒但更慢
|
||||
# - filtered 候选过多时,枚举 C(n,4) 会变慢,需限幅
|
||||
TRIANGLE_EARLY_EXIT_CANDIDATES = 3 # 找到3个候选即停(第4个由几何推算);原来4需跑完全adaptive
|
||||
TRIANGLE_ADAPTIVE_BLOCK_SIZES = (11,) # 只用1个block_size;原(11,21)跑两遍adaptive
|
||||
TRIANGLE_MAX_FILTERED_FOR_COMBO = 10 # 参与四点组合评分的最大候选数(超过则截断到最可能的一部分)
|
||||
|
||||
# ROI 局部阈值:四个象限各自 Otsu(+ 可选 ROI 内 adaptive),再合并候选。
|
||||
# 顺序:紧接在全局 Otsu 之后、整图 adaptive 之前(见 triangle_target.detect_triangle_markers)。
|
||||
# 用途:阴阳脸/大阴影下往往比「先整图 adaptive」更省时间且更稳;整图 adaptive 最慢,作补充。
|
||||
#
|
||||
# YOLO 已裁到靶区时,整幅小图上单一全局 Otsu 容易把环与四角揉在一个阈值里;可跳过第一轮「全局轮廓提取」,
|
||||
# 直接进入下面四象限 ROI Otsu(仍会算全局 b_otsu 供 relaxed approxPolyDP 回退)。整图模式勿开。
|
||||
TRIANGLE_SKIP_GLOBAL_OTSU_EXTRACT_ON_YOLO_ROI = True
|
||||
|
||||
TRIANGLE_ROI_ENABLED = False
|
||||
TRIANGLE_ROI_MIN_CANDIDATES = 3 # 候选数低于此值时启用 ROI 局部阈值(需至少 3 个点才能三角解算)
|
||||
TRIANGLE_ROI_OVERLAP_RATIO = 0.08 # 象限 ROI 的重叠比例(避免角标落在分割边界被切断)
|
||||
TRIANGLE_ROI_USE_ADAPTIVE = False # ROI 内关闭 adaptive(只跑ROI Otsu,省去4×adaptive);遇到阴阳脸再开
|
||||
|
||||
# 多路径融合:不同二值化路径若得到相近中心(dedup 格点),累加 path_votes,后续优先参与四点组合。
|
||||
TRIANGLE_MULTI_PATH_VOTE = True
|
||||
|
||||
# 失败回退(仍不足 TRIANGLE_FALLBACK_MIN_CANDIDATES 时按序尝试,每条仅在前序仍不足时执行)
|
||||
TRIANGLE_FALLBACK_MIN_CANDIDATES = 3
|
||||
# 对同一幅 Otsu 二值图用更宽松的 approxPolyDP,找回被“切角”的轮廓
|
||||
TRIANGLE_FALLBACK_RELAXED_EPS = True
|
||||
TRIANGLE_RELAXED_POLY_EPS_SCALE = 1.65
|
||||
# Black-hat(顶帽逆):突出比周围暗的斑块,再 Otsu;对阴影/照度不均往往有效,略慢于纯 Otsu
|
||||
TRIANGLE_FALLBACK_BLACKHAT = True
|
||||
TRIANGLE_BLACKHAT_KERNEL_FRAC = 0.018 # 核大小 ≈ min(h,w)*frac,取奇数,范围约 [7, 31]
|
||||
|
||||
# ── YOLO(NPU) 靶环 ROI → 裁剪后再跑三角形(减小 CPU 处理面积)──────────────────
|
||||
# 日志里 net_in=W×H 来自 .mud 模型(det.input_width/height),不是这里配置的。
|
||||
TRIANGLE_YOLO_ROI_ENABLE = True
|
||||
TRIANGLE_YOLO_MODEL_PATH = APP_DIR + "/model_270139.mud"
|
||||
# 参与 ROI 的类别:多类时只填「整靶/靶环」的 id;不要填角标类,否则 union 仍可对,但 largest 会偏小。
|
||||
TRIANGLE_YOLO_RING_CLASS_IDS = (0,)
|
||||
TRIANGLE_YOLO_CONF_TH = 0.7
|
||||
TRIANGLE_YOLO_IOU_TH = 0.45
|
||||
# YOLO 首次/临界帧可能在高阈值下 0 框;启用后仅在 0 候选时用较低阈值重试一次。
|
||||
# 后续仍会经过 min_box_side、ROI aspect、三角形几何校验,避免直接放大假阳性。
|
||||
TRIANGLE_YOLO_RETRY_ON_EMPTY = True
|
||||
TRIANGLE_YOLO_RETRY_CONF_TH = 0.5
|
||||
TRIANGLE_YOLO_ROI_MARGIN_FRAC = 0.11
|
||||
# union: 所有候选框外接矩形(一类多框:环+四角);largest: 只取面积最大的框
|
||||
TRIANGLE_YOLO_ROI_MERGE_MODE = "union"
|
||||
# native: Maix 已将框映射到相机分辨率;letterbox: 框在网络输入坐标需逆变换(重复映射会出细条 ROI)
|
||||
TRIANGLE_YOLO_COORD_MODE = "native"
|
||||
# 参与 ROI 合并前丢弃过小的框(低 conf 时边角 1×1 假阳性)
|
||||
TRIANGLE_YOLO_MIN_BOX_SIDE_PX = 8
|
||||
TRIANGLE_YOLO_REJECT_BAD_ROI = True
|
||||
# try_triangle_scoring 收到 ROI 后裁剪的最小边长(像素),过小则退回整图
|
||||
TRIANGLE_CROP_ROI_MIN_SIDE_PX = 64
|
||||
# 射箭保存图 / 预览上绘制 YOLO 靶环 ROI 矩形 (x0,y0,x1,y1),核对是否裁准;不需要时改 False
|
||||
TRIANGLE_YOLO_DRAW_ROI_ON_SHOT = True
|
||||
# 开机阶段预加载 YOLO detector;detect 使用 dual_buff=False,避免返回上一帧结果。
|
||||
TRIANGLE_YOLO_PRELOAD_ON_BOOT = True
|
||||
|
||||
# ── 第二段 YOLO:仅在 Stage1 裁切出的靶环图上推理(与合成 stage2 训练数据一致)→ 子框内传统算法取直角点 ──
|
||||
# Stage1 靶环裁切内如何找黑三角标记(对比耗时时可切换):
|
||||
# "yolo" — 调 Stage2 黑三角模型得子框,再子框内传统提取(需 TRIANGLE_BLACK_YOLO_ENABLE=True)。
|
||||
# "traditional" — 不调 Stage2 模型;仅在 Stage1 ROI 整幅上跑传统 detect_triangle_markers(与 yolo 路径对比用)。
|
||||
TRIANGLE_BLACK_TRIANGLE_LOCATE_MODE = "traditional"
|
||||
# True 时每箭另打一枪端到端耗时:yolo_ring + yolo_black + try_triangle_scoring 墙钟(毫秒)
|
||||
TRIANGLE_LOG_E2E_TIMING = True
|
||||
TRIANGLE_BLACK_YOLO_ENABLE = True
|
||||
TRIANGLE_BLACK_YOLO_MODEL_PATH = APP_DIR + "/model_270820.mud"
|
||||
TRIANGLE_BLACK_YOLO_CLASS_IDS = (0,)
|
||||
TRIANGLE_BLACK_YOLO_CONF_TH = 0.5
|
||||
TRIANGLE_BLACK_YOLO_IOU_TH = 0.45
|
||||
# Maix YOLOv5 detect 返回的框已映射到传入的 Stage1 裁切图坐标;contain/letterbox 是模型内部预处理。
|
||||
TRIANGLE_BLACK_YOLO_COORD_MODE = "native"
|
||||
# 子框相对 YOLO 框的扩展(在靶环裁切图坐标系下),利于传统算法取边
|
||||
TRIANGLE_BLACK_YOLO_BOX_MARGIN_FRAC = 0.08
|
||||
TRIANGLE_BLACK_YOLO_MIN_BOX_SIDE_PX = 6.0
|
||||
# 子框传统检测不足 3 个时是否回退为「整幅靶环 ROI」上的原 detect_triangle_markers
|
||||
TRIANGLE_BLACK_YOLO_FALLBACK_ON_PATCH_FAIL = True
|
||||
# Stage2 子框内传统提取使用的灰度(有缩略时默认在 Stage1 全分辨率灰度上切片):
|
||||
# "rgb" — 仅用 RGB→灰度(不再做 Unsharp、不做 V 抑制),最省 CPU(推荐子框已对准黑三角时)。
|
||||
# "global" — 与整幅 ROI 三角流程同一张 gray(含 TRIANGLE_GRAY_MODE 的 v_suppress 与锐化);更稳但更耗时。
|
||||
TRIANGLE_BLACK_YOLO_PATCH_GRAY_SOURCE = "rgb"
|
||||
# Stage2 子框内轮廓→三角形:approxPolyDP 的 ε=周长×FRAC×mult。边模糊时略增大 FRAC 或保留多级 mult。
|
||||
TRIANGLE_PATCH_APPROXPOLY_FRAC = 0.055
|
||||
TRIANGLE_PATCH_APPROXPOLY_RELAX_MULTS = (1.0, 1.3, 1.65)
|
||||
# Otsu/Adaptive 前对子框灰度轻模糊:0=关闭;3 或 5=Gaussian ksize(须为奇数),压锯齿利于收成 3 顶点
|
||||
TRIANGLE_PATCH_PRE_BLUR_KSIZE = 0
|
||||
TRIANGLE_BLACK_YOLO_PRELOAD_ON_BOOT = True
|
||||
# 每箭是否在日志中打印黑三角 detect 统计(raw/类过滤/是否在环内);调通后可 False 减日志
|
||||
TRIANGLE_BLACK_YOLO_LOG_EACH_SHOT = True
|
||||
# True=每次射箭将 Stage1 裁切图(黑三角模型输入)存为 JPEG;调试用,量产请 False
|
||||
TRIANGLE_BLACK_YOLO_SAVE_ROI_CROP = True
|
||||
# 存盘目录;空字符串表示使用 PHOTO_DIR + "/stage2_roi"
|
||||
TRIANGLE_BLACK_YOLO_ROI_CROP_DIR = ""
|
||||
# 存盘 JPEG 上绘制 Stage2(黑三角 YOLO)最终子框(绿框 + s2_0… 标签)
|
||||
TRIANGLE_BLACK_YOLO_SAVE_ROI_DRAW_BOXES = True
|
||||
|
||||
FLASH_LASER_WHILE_SHOOTING = False # 是否在拍摄时闪一下激光(True=闪,False=不闪)
|
||||
FLASH_LASER_DURATION_MS = 1000 # 闪一下激光的持续时间(毫秒)
|
||||
|
||||
# ==================== 显示配置 ====================
|
||||
LASER_COLOR = (0, 255, 0) # RGB颜色
|
||||
LASER_THICKNESS = 1
|
||||
LASER_LENGTH = 2
|
||||
|
||||
# ==================== 图像保存配置 ====================
|
||||
SAVE_IMAGE_ENABLED = True # 是否保存图像(True=保存,False=不保存)
|
||||
PHOTO_DIR = "/root/phot" # 照片存储目录
|
||||
MAX_IMAGES = 1000
|
||||
# Stage2 调试目录(默认 PHOTO_DIR/stage2_roi)内 JPEG 最多保留张数;None 表示与 MAX_IMAGES 相同
|
||||
TRIANGLE_BLACK_YOLO_STAGE2_ROI_MAX_IMAGES = None
|
||||
|
||||
SHOW_CAMERA_PHOTO_WHILE_SHOOTING = False # 是否在拍摄时显示摄像头图像(True=显示,False=不显示),建议在连着USB测试过程中打开
|
||||
|
||||
# ==================== OTA配置 ====================
|
||||
MAX_BACKUPS = 5
|
||||
LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
|
||||
LOG_BACKUP_COUNT = 5
|
||||
|
||||
# ==================== 引脚映射配置(板载 WiFi,I2C5)====================
|
||||
PIN_MAPPINGS = {
|
||||
"A18": "UART1_RX",
|
||||
"A19": "UART1_TX",
|
||||
"A29": "UART2_RX",
|
||||
"A28": "UART2_TX",
|
||||
"A15": "I2C5_SCL",
|
||||
"A27": "I2C5_SDA",
|
||||
"A24": "GPIOA24", # 电源板关机控制
|
||||
}
|
||||
|
||||
# ==================== 电源配置 ====================
|
||||
AUTO_POWER_OFF_IN_SECONDS = 10 * 60 # 自动关机时间(秒),0表示不自动关机
|
||||
|
||||
BATTERY_SOC_LPF_ALPHA = 0.5
|
||||
BATTERY_SOC_AVG_WINDOW = 5
|
||||
|
||||
72
cpp_ext/CMakeLists.txt
Normal file
72
cpp_ext/CMakeLists.txt
Normal file
@@ -0,0 +1,72 @@
|
||||
cmake_minimum_required(VERSION 3.16)
|
||||
project(archery_netcore CXX)
|
||||
|
||||
set(CMAKE_SYSTEM_NAME Linux)
|
||||
set(CMAKE_SYSTEM_PROCESSOR riscv64)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
if(NOT DEFINED PY_INCLUDE_DIR)
|
||||
message(FATAL_ERROR "PY_INCLUDE_DIR not set")
|
||||
endif()
|
||||
if(NOT DEFINED PY_LIB)
|
||||
message(FATAL_ERROR "PY_LIB not set")
|
||||
endif()
|
||||
if(NOT DEFINED PY_EXT_SUFFIX)
|
||||
message(FATAL_ERROR "PY_EXT_SUFFIX not set")
|
||||
endif()
|
||||
if(NOT DEFINED MAIXCDK_PATH)
|
||||
message(FATAL_ERROR "MAIXCDK_PATH not set (need components/3rd_party/pybind11)")
|
||||
endif()
|
||||
|
||||
add_library(archery_netcore MODULE
|
||||
archery_netcore.cpp
|
||||
native_logger.cpp
|
||||
utils.cpp
|
||||
decrypt_ota_file.cpp
|
||||
msg_handler.cpp
|
||||
tcp_ssl_password.cpp
|
||||
)
|
||||
|
||||
target_include_directories(archery_netcore PRIVATE
|
||||
"${PY_INCLUDE_DIR}"
|
||||
"${MAIXCDK_PATH}/components/3rd_party/pybind11/pybind11/include"
|
||||
"${MAIXCDK_PATH}/components/3rd_party/openssl/include"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/third_party" # 添加 nlohmann/json 路径
|
||||
)
|
||||
|
||||
# 尽量减少 .so 体积并增加逆向成本
|
||||
target_compile_options(archery_netcore PRIVATE
|
||||
-Os
|
||||
-ffunction-sections
|
||||
-fdata-sections
|
||||
-fvisibility=hidden
|
||||
-fvisibility-inlines-hidden
|
||||
)
|
||||
target_link_options(archery_netcore PRIVATE
|
||||
-Wl,--gc-sections
|
||||
-Wl,-s
|
||||
)
|
||||
|
||||
set_target_properties(archery_netcore PROPERTIES
|
||||
PREFIX ""
|
||||
SUFFIX "${PY_EXT_SUFFIX}"
|
||||
)
|
||||
|
||||
# OpenSSL (for AES-256-GCM decrypt)
|
||||
# 使用 MaixCDK 提供的 OpenSSL 库(在 so/maixcam 目录下)
|
||||
set(OPENSSL_LIB_DIR "${MAIXCDK_PATH}/components/3rd_party/openssl/so/maixcam")
|
||||
if(EXISTS "${OPENSSL_LIB_DIR}/libcrypto.so")
|
||||
target_link_directories(archery_netcore PRIVATE "${OPENSSL_LIB_DIR}")
|
||||
target_link_libraries(archery_netcore PRIVATE "${PY_LIB}" crypto ssl)
|
||||
message(STATUS "Using OpenSSL from MaixCDK: ${OPENSSL_LIB_DIR}")
|
||||
else()
|
||||
# Fallback: 尝试 find_package 或系统库
|
||||
find_package(OpenSSL QUIET)
|
||||
if(OpenSSL_FOUND)
|
||||
target_link_libraries(archery_netcore PRIVATE "${PY_LIB}" OpenSSL::Crypto OpenSSL::SSL)
|
||||
else()
|
||||
message(WARNING "OpenSSL not found in MaixCDK, trying system libraries (may fail)")
|
||||
target_link_libraries(archery_netcore PRIVATE "${PY_LIB}" crypto ssl)
|
||||
endif()
|
||||
endif()
|
||||
117
cpp_ext/archery_netcore.cpp
Normal file
117
cpp_ext/archery_netcore.cpp
Normal file
@@ -0,0 +1,117 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // 支持 std::vector, std::map 等
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <cstring>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <array>
|
||||
|
||||
#include "msg_handler.hpp"
|
||||
#include "native_logger.hpp"
|
||||
#include "decrypt_ota_file.hpp"
|
||||
#include "utils.hpp"
|
||||
#include "tcp_ssl_password.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
using json = nlohmann::json;
|
||||
|
||||
namespace {
|
||||
// 配置项
|
||||
const std::string _cfg_server_ip = "www.shelingxingqiu.com";
|
||||
const int _cfg_server_port = 50005;
|
||||
const std::string _cfg_device_id_file = "/device_key";
|
||||
|
||||
|
||||
}
|
||||
|
||||
// 定义获取配置的函数
|
||||
py::dict get_config() {
|
||||
py::dict config;
|
||||
config["SERVER_IP"] = _cfg_server_ip;
|
||||
config["SERVER_PORT"] = _cfg_server_port;
|
||||
return config;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
PYBIND11_MODULE(archery_netcore, m) {
|
||||
m.doc() = "Archery net core (native, pybind11).";
|
||||
|
||||
// Optional: configure native logger from Python.
|
||||
// Default log file: /maixapp/apps/t11/netcore.log
|
||||
m.def("set_log_file", [](const std::string& path) { netcore::set_log_file(path); }, py::arg("path"));
|
||||
m.def("set_log_level", [](int level) {
|
||||
if (level < 0) level = 0;
|
||||
if (level > 3) level = 3;
|
||||
netcore::set_log_level(static_cast<netcore::LogLevel>(level));
|
||||
}, py::arg("level"));
|
||||
m.def("log_test", [](const std::string& msg) {
|
||||
netcore::log_info(std::string("log_test: ") + msg);
|
||||
}, py::arg("msg"));
|
||||
|
||||
m.def("make_packet", &netcore::make_packet,
|
||||
"Pack TCP packet: header (len+type+checksum) + JSON body",
|
||||
py::arg("msg_type"), py::arg("body_dict"));
|
||||
|
||||
m.def("parse_packet", &netcore::parse_packet,
|
||||
"Parse TCP packet, return (msg_type, body_dict)");
|
||||
|
||||
m.def("get_config", &get_config, "Get system configuration");
|
||||
|
||||
m.def(
|
||||
"calculate_tcp_ssl_password",
|
||||
&netcore::calculate_tcp_ssl_password,
|
||||
"Calculate TCP SSL password: hex(md5(hex(md5(device_id)) + iccid))",
|
||||
py::arg("device_id"),
|
||||
py::arg("iccid")
|
||||
);
|
||||
|
||||
m.def(
|
||||
"decrypt_ota_file",
|
||||
[](const std::string& input_path, const std::string& output_zip_path) {
|
||||
netcore::log_info(std::string("decrypt_ota_file in=") + input_path + " out=" + output_zip_path);
|
||||
return netcore::decrypt_ota_file_impl(input_path, output_zip_path);
|
||||
},
|
||||
py::arg("input_path"),
|
||||
py::arg("output_zip_path"),
|
||||
"Decrypt OTA encrypted file (MAGIC|nonce|ciphertext|tag) to plaintext zip."
|
||||
);
|
||||
|
||||
// Minimal demo: return actions for inner_cmd=41 (manual trigger + ack)
|
||||
m.def("actions_for_inner_cmd", [](int inner_cmd) {
|
||||
py::list actions;
|
||||
|
||||
if (inner_cmd == 41) {
|
||||
// 1) set manual trigger flag
|
||||
{
|
||||
py::dict a;
|
||||
a["type"] = "SET_FLAG";
|
||||
py::dict args;
|
||||
args["name"] = "manual_trigger_flag";
|
||||
args["value"] = true;
|
||||
a["args"] = args;
|
||||
actions.append(a);
|
||||
}
|
||||
|
||||
// 2) enqueue trigger_ack
|
||||
{
|
||||
py::dict a;
|
||||
a["type"] = "ENQUEUE";
|
||||
py::dict args;
|
||||
args["msg_type"] = 2;
|
||||
args["high"] = false;
|
||||
py::dict body;
|
||||
body["result"] = "trigger_ack";
|
||||
args["body"] = body;
|
||||
a["args"] = args;
|
||||
actions.append(a);
|
||||
}
|
||||
}
|
||||
|
||||
return actions;
|
||||
});
|
||||
}
|
||||
135
cpp_ext/decrypt_ota_file.cpp
Normal file
135
cpp_ext/decrypt_ota_file.cpp
Normal file
@@ -0,0 +1,135 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // 支持 std::vector, std::map 等
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <cstring>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <array>
|
||||
#include <openssl/evp.h>
|
||||
#include <algorithm>
|
||||
#include "native_logger.hpp"
|
||||
|
||||
namespace netcore{
|
||||
|
||||
// OTA AEAD format: MAGIC(7) | nonce(12) | ciphertext(N) | tag(16)
|
||||
constexpr const char* kOtaMagic = "AROTAE1";
|
||||
constexpr size_t kOtaMagicLen = 7;
|
||||
constexpr size_t kGcmNonceLen = 12;
|
||||
constexpr size_t kGcmTagLen = 16;
|
||||
|
||||
// 固定 32-byte AES-256-GCM key(提高被直接查看的成本;不是绝对安全)
|
||||
// 注意:需要与打包端传入的 --aead-key-hex 保持一致。
|
||||
static std::array<uint8_t, 32> ota_key_bytes() {
|
||||
// 简单拆分混淆:key = a XOR b
|
||||
static const std::array<uint8_t, 32> a = {
|
||||
0x92,0x99,0x4d,0x06,0x6f,0xb6,0xa6,0x3d,0x85,0x08,0xbe,0x73,0x5e,0x73,0x4d,0x8a,
|
||||
0x53,0x88,0xe6,0x99,0xfc,0x10,0x29,0xb9,0x16,0x9b,0xe7,0x0c,0x65,0x21,0x1c,0xce
|
||||
};
|
||||
static const std::array<uint8_t, 32> b = {
|
||||
0xcf,0x60,0xa2,0xc2,0x32,0x7a,0x61,0xb0,0x4c,0x8e,0x8a,0x62,0x31,0xc7,0x82,0xff,
|
||||
0xec,0xac,0xa1,0x04,0x2a,0x4d,0xaa,0xf2,0xb0,0x5b,0x39,0x2b,0xf4,0xb3,0xad,0xad
|
||||
};
|
||||
std::array<uint8_t, 32> k{};
|
||||
for (size_t i = 0; i < k.size(); i++) k[i] = static_cast<uint8_t>(a[i] ^ b[i]);
|
||||
return k;
|
||||
}
|
||||
|
||||
static bool read_file_all(const std::string& path, std::vector<uint8_t>& out) {
|
||||
std::ifstream ifs(path, std::ios::binary);
|
||||
if (!ifs) return false;
|
||||
ifs.seekg(0, std::ios::end);
|
||||
std::streampos size = ifs.tellg();
|
||||
if (size <= 0) return false;
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
out.resize(static_cast<size_t>(size));
|
||||
if (!ifs.read(reinterpret_cast<char*>(out.data()), size)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool write_file_all(const std::string& path, const uint8_t* data, size_t len) {
|
||||
std::ofstream ofs(path, std::ios::binary | std::ios::trunc);
|
||||
if (!ofs) return false;
|
||||
ofs.write(reinterpret_cast<const char*>(data), static_cast<std::streamsize>(len));
|
||||
return static_cast<bool>(ofs);
|
||||
}
|
||||
|
||||
bool decrypt_ota_file_impl(const std::string& input_path, const std::string& output_zip_path) {
|
||||
std::vector<uint8_t> in;
|
||||
if (!netcore::read_file_all(input_path, in)) {
|
||||
netcore::log_error(std::string("decrypt_ota_file: read failed: ") + input_path);
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t min_len = kOtaMagicLen + kGcmNonceLen + kGcmTagLen + 1;
|
||||
if (in.size() < min_len) {
|
||||
netcore::log_error("decrypt_ota_file: too short");
|
||||
return false;
|
||||
}
|
||||
if (!std::equal(in.begin(), in.begin() + kOtaMagicLen, reinterpret_cast<const uint8_t*>(kOtaMagic))) {
|
||||
netcore::log_error("decrypt_ota_file: bad magic");
|
||||
return false;
|
||||
}
|
||||
|
||||
const uint8_t* nonce = in.data() + kOtaMagicLen;
|
||||
const uint8_t* ct_and_tag = in.data() + kOtaMagicLen + kGcmNonceLen;
|
||||
const size_t ct_and_tag_len = in.size() - (kOtaMagicLen + kGcmNonceLen);
|
||||
if (ct_and_tag_len <= kGcmTagLen) {
|
||||
netcore::log_error("decrypt_ota_file: no ciphertext");
|
||||
return false;
|
||||
}
|
||||
const size_t ciphertext_len = ct_and_tag_len - kGcmTagLen;
|
||||
const uint8_t* ciphertext = ct_and_tag;
|
||||
const uint8_t* tag = ct_and_tag + ciphertext_len;
|
||||
|
||||
std::vector<uint8_t> plain(ciphertext_len);
|
||||
int out_len1 = 0;
|
||||
int out_len2 = 0;
|
||||
|
||||
EVP_CIPHER_CTX* ctx = EVP_CIPHER_CTX_new();
|
||||
if (!ctx) {
|
||||
netcore::log_error("decrypt_ota_file: EVP_CIPHER_CTX_new failed");
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ok = false;
|
||||
auto key = ota_key_bytes();
|
||||
|
||||
do {
|
||||
if (1 != EVP_DecryptInit_ex(ctx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr)) {
|
||||
netcore::log_error("decrypt_ota_file: DecryptInit failed");
|
||||
break;
|
||||
}
|
||||
if (1 != EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, static_cast<int>(kGcmNonceLen), nullptr)) {
|
||||
netcore::log_error("decrypt_ota_file: set ivlen failed");
|
||||
break;
|
||||
}
|
||||
if (1 != EVP_DecryptInit_ex(ctx, nullptr, nullptr, key.data(), nonce)) {
|
||||
netcore::log_error("decrypt_ota_file: set key/iv failed");
|
||||
break;
|
||||
}
|
||||
if (1 != EVP_DecryptUpdate(ctx, plain.data(), &out_len1, ciphertext, static_cast<int>(ciphertext_len))) {
|
||||
netcore::log_error("decrypt_ota_file: update failed");
|
||||
break;
|
||||
}
|
||||
if (1 != EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, static_cast<int>(kGcmTagLen), const_cast<uint8_t*>(tag))) {
|
||||
netcore::log_error("decrypt_ota_file: set tag failed");
|
||||
break;
|
||||
}
|
||||
if (1 != EVP_DecryptFinal_ex(ctx, plain.data() + out_len1, &out_len2)) {
|
||||
netcore::log_error("decrypt_ota_file: final failed (auth tag mismatch?)");
|
||||
break;
|
||||
}
|
||||
const size_t plain_len = static_cast<size_t>(out_len1 + out_len2);
|
||||
if (!netcore::write_file_all(output_zip_path, plain.data(), plain_len)) {
|
||||
netcore::log_error(std::string("decrypt_ota_file: write failed: ") + output_zip_path);
|
||||
break;
|
||||
}
|
||||
ok = true;
|
||||
} while (false);
|
||||
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return ok;
|
||||
}
|
||||
}
|
||||
7
cpp_ext/decrypt_ota_file.hpp
Normal file
7
cpp_ext/decrypt_ota_file.hpp
Normal file
@@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace netcore{
|
||||
bool decrypt_ota_file_impl(const std::string& input_path, const std::string& output_zip_path);
|
||||
}
|
||||
113
cpp_ext/msg_handler.cpp
Normal file
113
cpp_ext/msg_handler.cpp
Normal file
@@ -0,0 +1,113 @@
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include "native_logger.hpp"
|
||||
#include "msg_handler.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
using json = nlohmann::json;
|
||||
|
||||
namespace netcore {
|
||||
// 打包 TCP 数据包
|
||||
py::bytes make_packet(int msg_type, py::dict body_dict) {
|
||||
netcore::log_debug(std::string("make_packet msg_type=") + std::to_string(msg_type));
|
||||
// 1) 将 py::dict 转为 JSON 字符串
|
||||
json j = netcore::py_dict_to_json(body_dict);
|
||||
std::string body_str = j.dump();
|
||||
|
||||
// 2) 计算 body_len 和 checksum
|
||||
uint32_t body_len = body_str.size();
|
||||
uint32_t checksum = body_len + msg_type;
|
||||
|
||||
// 3) 打包头部(大端序)
|
||||
std::vector<uint8_t> packet;
|
||||
packet.reserve(12 + body_len);
|
||||
|
||||
// body_len (big-endian, 4 bytes)
|
||||
packet.push_back((body_len >> 24) & 0xFF);
|
||||
packet.push_back((body_len >> 16) & 0xFF);
|
||||
packet.push_back((body_len >> 8) & 0xFF);
|
||||
packet.push_back(body_len & 0xFF);
|
||||
|
||||
// msg_type (big-endian, 4 bytes)
|
||||
packet.push_back((msg_type >> 24) & 0xFF);
|
||||
packet.push_back((msg_type >> 16) & 0xFF);
|
||||
packet.push_back((msg_type >> 8) & 0xFF);
|
||||
packet.push_back(msg_type & 0xFF);
|
||||
|
||||
// checksum (big-endian, 4 bytes)
|
||||
packet.push_back((checksum >> 24) & 0xFF);
|
||||
packet.push_back((checksum >> 16) & 0xFF);
|
||||
packet.push_back((checksum >> 8) & 0xFF);
|
||||
packet.push_back(checksum & 0xFF);
|
||||
|
||||
// 4) 追加 body
|
||||
packet.insert(packet.end(), body_str.begin(), body_str.end());
|
||||
|
||||
netcore::log_debug(std::string("make_packet done bytes=") + std::to_string(packet.size()));
|
||||
return py::bytes(reinterpret_cast<const char*>(packet.data()), packet.size());
|
||||
}
|
||||
|
||||
// 解析 TCP 数据包
|
||||
py::tuple parse_packet(py::bytes data) {
|
||||
// 1) 转换为 bytes view
|
||||
py::buffer_info buf = py::buffer(data).request();
|
||||
if (buf.size < 12) {
|
||||
netcore::log_error(std::string("parse_packet too_short len=") + std::to_string(buf.size));
|
||||
return py::make_tuple(py::none(), py::none());
|
||||
}
|
||||
|
||||
const uint8_t* ptr = static_cast<const uint8_t*>(buf.ptr);
|
||||
|
||||
// 2) 解析头部(大端序)
|
||||
uint32_t body_len = (ptr[0] << 24) | (ptr[1] << 16) | (ptr[2] << 8) | ptr[3];
|
||||
uint32_t msg_type = (ptr[4] << 24) | (ptr[5] << 16) | (ptr[6] << 8) | ptr[7];
|
||||
uint32_t checksum = (ptr[8] << 24) | (ptr[9] << 16) | (ptr[10] << 8) | ptr[11];
|
||||
|
||||
// 3) 校验 checksum(可选,你现有代码不强制校验)
|
||||
// if (checksum != (body_len + msg_type)) {
|
||||
// return py::make_tuple(py::none(), py::none());
|
||||
// }
|
||||
|
||||
// 4) 检查长度
|
||||
uint32_t expected_len = 12 + body_len;
|
||||
if (buf.size < expected_len) {
|
||||
// 半包
|
||||
netcore::log_warn(std::string("parse_packet incomplete got=") + std::to_string(buf.size) +
|
||||
" expected=" + std::to_string(expected_len));
|
||||
return py::make_tuple(py::none(), py::none());
|
||||
}
|
||||
|
||||
// 5) 防御性检查:如果 data 比预期长,说明可能有粘包
|
||||
// (只解析第一个包,忽略多余数据)
|
||||
if (buf.size > expected_len) {
|
||||
netcore::log_warn(std::string("parse_packet concat got=") + std::to_string(buf.size) +
|
||||
" expected=" + std::to_string(expected_len) +
|
||||
" body_len=" + std::to_string(body_len) +
|
||||
" msg_type=" + std::to_string(msg_type));
|
||||
}
|
||||
|
||||
// 6) 提取 body 并解析 JSON
|
||||
std::string body_str(reinterpret_cast<const char*>(ptr + 12), body_len);
|
||||
|
||||
try {
|
||||
json j = json::parse(body_str);
|
||||
py::dict body_dict = netcore::json_to_py_dict(j);
|
||||
return py::make_tuple(py::int_(msg_type), body_dict);
|
||||
} catch (const json::parse_error& e) {
|
||||
// JSON 解析失败,返回 raw(兼容你现有的逻辑)
|
||||
netcore::log_error(std::string("parse_packet json_parse_error: ") + e.what());
|
||||
py::dict raw_dict;
|
||||
raw_dict["raw"] = body_str;
|
||||
return py::make_tuple(py::int_(msg_type), raw_dict);
|
||||
} catch (const std::exception& e) {
|
||||
netcore::log_error(std::string("parse_packet json_parse_error: ") + e.what());
|
||||
py::dict raw_dict;
|
||||
raw_dict["raw"] = body_str;
|
||||
return py::make_tuple(py::int_(msg_type), raw_dict);
|
||||
}
|
||||
}
|
||||
}
|
||||
14
cpp_ext/msg_handler.hpp
Normal file
14
cpp_ext/msg_handler.hpp
Normal file
@@ -0,0 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // 支持 std::vector, std::map 等
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace netcore {
|
||||
|
||||
// 打包 TCP 数据包
|
||||
py::bytes make_packet(int msg_type, py::dict body_dict);
|
||||
// 解包 TCP 数据包
|
||||
py::tuple parse_packet(py::bytes data);
|
||||
}
|
||||
100
cpp_ext/native_logger.cpp
Normal file
100
cpp_ext/native_logger.cpp
Normal file
@@ -0,0 +1,100 @@
|
||||
#include "native_logger.hpp"
|
||||
|
||||
#include <cerrno>
|
||||
#include <cstring>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <time.h>
|
||||
#include <unistd.h>
|
||||
|
||||
namespace netcore {
|
||||
|
||||
static std::mutex g_mu;
|
||||
static int g_fd = -1;
|
||||
static std::string g_path = "netcore.log";
|
||||
static LogLevel g_level = LogLevel::kDebug; //LogLevel::kInfo;
|
||||
|
||||
static const char* level_name(LogLevel lvl) {
|
||||
switch (lvl) {
|
||||
case LogLevel::kError: return "E";
|
||||
case LogLevel::kWarn: return "W";
|
||||
case LogLevel::kInfo: return "I";
|
||||
case LogLevel::kDebug: return "D";
|
||||
default: return "?";
|
||||
}
|
||||
}
|
||||
|
||||
static void ensure_open_locked() {
|
||||
if (g_path.empty()) return;
|
||||
if (g_fd >= 0) return;
|
||||
g_fd = ::open(g_path.c_str(), O_CREAT | O_WRONLY | O_APPEND, 0644);
|
||||
}
|
||||
|
||||
void set_log_file(const std::string& path) {
|
||||
std::lock_guard<std::mutex> lk(g_mu);
|
||||
g_path = path;
|
||||
if (g_fd >= 0) {
|
||||
::close(g_fd);
|
||||
g_fd = -1;
|
||||
}
|
||||
ensure_open_locked();
|
||||
}
|
||||
|
||||
void set_log_level(LogLevel level) {
|
||||
std::lock_guard<std::mutex> lk(g_mu);
|
||||
g_level = level;
|
||||
}
|
||||
|
||||
void log(LogLevel level, const std::string& msg) {
|
||||
std::lock_guard<std::mutex> lk(g_mu);
|
||||
if (static_cast<int>(level) > static_cast<int>(g_level)) return;
|
||||
if (g_path.empty()) return;
|
||||
|
||||
ensure_open_locked();
|
||||
if (g_fd < 0) {
|
||||
// Last resort: stderr (avoid any Python APIs)
|
||||
::write(STDERR_FILENO, msg.c_str(), msg.size());
|
||||
::write(STDERR_FILENO, "\n", 1);
|
||||
return;
|
||||
}
|
||||
|
||||
// Timestamp: epoch milliseconds (simple and cheap)
|
||||
struct timespec ts;
|
||||
clock_gettime(CLOCK_REALTIME, &ts);
|
||||
// long long ms = (long long)ts.tv_sec * 1000LL + ts.tv_nsec / 1000000LL;
|
||||
// 1. 将秒数转换为本地时间结构体 struct tm
|
||||
struct tm *tm_info = localtime(&ts.tv_sec);
|
||||
|
||||
// 2. 准备一个缓冲区来存储时间字符串
|
||||
char buffer[30];
|
||||
|
||||
// 3. 格式化秒的部分
|
||||
// 格式: 年-月-日 时:分:秒
|
||||
strftime(buffer, sizeof(buffer), "%Y-%m-%d %H:%M:%S", tm_info);
|
||||
|
||||
// 4. 计算毫秒部分并追加到字符串中
|
||||
// ts.tv_nsec 是纳秒,除以 1,000,000 得到毫秒
|
||||
char ms_buffer[8];
|
||||
snprintf(ms_buffer, sizeof(ms_buffer), ".%03ld", ts.tv_nsec / 1000000);
|
||||
|
||||
// Build one line to keep writes atomic-ish
|
||||
char head[256];
|
||||
int n = ::snprintf(head, sizeof(head), "[%s%s] [%s] ", buffer, ms_buffer, level_name(level));
|
||||
if (n < 0) n = 0;
|
||||
|
||||
::write(g_fd, head, (size_t)n);
|
||||
::write(g_fd, msg.c_str(), msg.size());
|
||||
::write(g_fd, "\n", 1);
|
||||
}
|
||||
|
||||
void log_debug(const std::string& msg) { log(LogLevel::kDebug, msg); }
|
||||
void log_info (const std::string& msg) { log(LogLevel::kInfo, msg); }
|
||||
void log_warn (const std::string& msg) { log(LogLevel::kWarn, msg); }
|
||||
void log_error(const std::string& msg) { log(LogLevel::kError, msg); }
|
||||
|
||||
} // namespace netcore
|
||||
|
||||
28
cpp_ext/native_logger.hpp
Normal file
28
cpp_ext/native_logger.hpp
Normal file
@@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace netcore {
|
||||
|
||||
enum class LogLevel : int {
|
||||
kError = 0,
|
||||
kWarn = 1,
|
||||
kInfo = 2,
|
||||
kDebug = 3,
|
||||
};
|
||||
|
||||
// Set log file path. If empty, logging is disabled.
|
||||
void set_log_file(const std::string& path);
|
||||
|
||||
// Set minimum log level to write (default: kInfo).
|
||||
void set_log_level(LogLevel level);
|
||||
|
||||
// Log helpers (thread-safe, never calls into Python).
|
||||
void log(LogLevel level, const std::string& msg);
|
||||
void log_debug(const std::string& msg);
|
||||
void log_info(const std::string& msg);
|
||||
void log_warn(const std::string& msg);
|
||||
void log_error(const std::string& msg);
|
||||
|
||||
} // namespace netcore
|
||||
|
||||
24765
cpp_ext/third_party/nlohmann/json.hpp
vendored
Normal file
24765
cpp_ext/third_party/nlohmann/json.hpp
vendored
Normal file
File diff suppressed because it is too large
Load Diff
95
cpp_ext/utils.cpp
Normal file
95
cpp_ext/utils.cpp
Normal file
@@ -0,0 +1,95 @@
|
||||
#include <fstream>
|
||||
#include <cstring>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace netcore {
|
||||
// 辅助函数:将 py::dict 转为 nlohmann::json
|
||||
json py_dict_to_json(py::dict d) {
|
||||
json j;
|
||||
for (auto item : d) {
|
||||
std::string key = py::str(item.first);
|
||||
py::object val = py::reinterpret_borrow<py::object>(item.second);
|
||||
|
||||
if (py::isinstance<py::dict>(val)) {
|
||||
j[key] = py_dict_to_json(py::cast<py::dict>(val));
|
||||
} else if (py::isinstance<py::list>(val)) {
|
||||
py::list py_list = py::cast<py::list>(val);
|
||||
json arr = json::array();
|
||||
for (auto elem : py_list) {
|
||||
py::object elem_obj = py::reinterpret_borrow<py::object>(elem);
|
||||
if (py::isinstance<py::dict>(elem_obj)) {
|
||||
arr.push_back(py_dict_to_json(py::cast<py::dict>(elem_obj)));
|
||||
} else if (py::isinstance<py::int_>(elem_obj)) {
|
||||
arr.push_back(py::cast<int64_t>(elem_obj));
|
||||
} else if (py::isinstance<py::float_>(elem_obj)) {
|
||||
arr.push_back(py::cast<double>(elem_obj));
|
||||
} else {
|
||||
arr.push_back(py::str(elem_obj));
|
||||
}
|
||||
}
|
||||
j[key] = arr;
|
||||
} else if (py::isinstance<py::int_>(val)) {
|
||||
j[key] = py::cast<int64_t>(val);
|
||||
} else if (py::isinstance<py::float_>(val)) {
|
||||
j[key] = py::cast<double>(val);
|
||||
} else if (py::isinstance<py::bool_>(val)) {
|
||||
j[key] = py::cast<bool>(val);
|
||||
} else if (val.is_none()) {
|
||||
j[key] = nullptr;
|
||||
} else {
|
||||
j[key] = py::str(val);
|
||||
}
|
||||
}
|
||||
return j;
|
||||
}
|
||||
|
||||
// 辅助函数:将 nlohmann::json 转为 py::dict
|
||||
py::dict json_to_py_dict(const json& j) {
|
||||
py::dict d;
|
||||
if (j.is_object()) {
|
||||
for (auto& item : j.items()) {
|
||||
std::string key = item.key();
|
||||
json val = item.value();
|
||||
|
||||
if (val.is_object()) {
|
||||
d[py::str(key)] = json_to_py_dict(val);
|
||||
} else if (val.is_array()) {
|
||||
py::list py_list;
|
||||
for (auto& elem : val) {
|
||||
if (elem.is_object()) {
|
||||
py_list.append(json_to_py_dict(elem));
|
||||
} else if (elem.is_number_integer()) {
|
||||
py_list.append(py::int_(elem.get<int64_t>()));
|
||||
} else if (elem.is_number_float()) {
|
||||
py_list.append(py::float_(elem.get<double>()));
|
||||
} else if (elem.is_boolean()) {
|
||||
py_list.append(py::bool_(elem.get<bool>()));
|
||||
} else if (elem.is_null()) {
|
||||
py_list.append(py::none());
|
||||
} else {
|
||||
py_list.append(py::str(elem.get<std::string>()));
|
||||
}
|
||||
}
|
||||
d[py::str(key)] = py_list;
|
||||
} else if (val.is_number_integer()) {
|
||||
d[py::str(key)] = py::int_(val.get<int64_t>());
|
||||
} else if (val.is_number_float()) {
|
||||
d[py::str(key)] = py::float_(val.get<double>());
|
||||
} else if (val.is_boolean()) {
|
||||
d[py::str(key)] = py::bool_(val.get<bool>());
|
||||
} else if (val.is_null()) {
|
||||
d[py::str(key)] = py::none();
|
||||
} else {
|
||||
d[py::str(key)] = py::str(val.get<std::string>());
|
||||
}
|
||||
}
|
||||
}
|
||||
return d;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
15
cpp_ext/utils.hpp
Normal file
15
cpp_ext/utils.hpp
Normal file
@@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h> // 支持 std::vector, std::map 等
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <string>
|
||||
|
||||
namespace py = pybind11;
|
||||
using json = nlohmann::json;
|
||||
|
||||
namespace netcore {
|
||||
|
||||
json py_dict_to_json(py::dict d);
|
||||
py::dict json_to_py_dict(const json& j);
|
||||
}
|
||||
184
design_doc/algo.md
Normal file
184
design_doc/algo.md
Normal file
@@ -0,0 +1,184 @@
|
||||
1. 系统目标
|
||||
# 检测靶纸四角的等腰直角三角形标记(每个角一个)
|
||||
# 计算激光落点在靶面上的二维偏移(厘米)
|
||||
# 通过PnP算法估算靶面到相机的距离(米)
|
||||
|
||||
2. 核心算法流程
|
||||
2.1 三角形检测 (detect_triangle_markers)
|
||||
采用多策略级联保证鲁棒性:
|
||||
图像输入 → 多阈值策略 → 候选三角形过滤 → 四点匹配
|
||||
检测策略(按优先级):
|
||||
|
||||
1.全局Otsu二值化(最快,~10ms)
|
||||
2.自适应阈值(多种block size,光照不均时)
|
||||
3.ROI局部阈值(候选不足3个时,分象限独立处理)
|
||||
4.Black-Hat形态学增强(仍不足时,突出暗色标记)
|
||||
|
||||
三角形几何验证:
|
||||
# 必须是直角三角形(检查勾股定理,容差20%)
|
||||
# 两直角边长度差<20%
|
||||
# 内部像素足够暗(灰度≤130,暗像素比例≥30%)
|
||||
# 与周围背景对比度≥15灰度级
|
||||
|
||||
四点匹配算法:
|
||||
# 从候选三角形中枚举所有4点组合
|
||||
# 计算四边形评分:(对角比-1)*3 + (水平比-1) + (垂直比-1) + (边长偏差)*2
|
||||
# 选择评分最低的组合作为四角标记
|
||||
|
||||
2.2 单应性落点计算 (homography_calibration)
|
||||
建立图像坐标系 → 靶面坐标系(二维平面)的透视变换
|
||||
|
||||
将激光点像素坐标映射到靶面坐标(厘米)
|
||||
|
||||
使用RANSAC提高鲁棒性(阈值1像素)
|
||||
|
||||
2.3 PnP距离估计 (pnp_distance_meters)
|
||||
已知四个标记点的三维坐标(x,y,z,单位cm)
|
||||
|
||||
通过solvePnP求解相机外参(旋转+平移)
|
||||
|
||||
距离 = ‖平移向量‖ / 100(转换为米)
|
||||
|
||||
3. 关键优化策略
|
||||
3.1 多路径投票
|
||||
同一图像区域被不同二值化方法检测到时,path_votes++
|
||||
|
||||
选择投票数高的候选,提高检测可信度
|
||||
|
||||
3.2 早退机制
|
||||
候选≥3个 且 覆盖3个以上象限 → 停止更多阈值尝试
|
||||
|
||||
大幅降低嵌入式设备计算开销
|
||||
|
||||
3.3 3点补全机制
|
||||
当只检测到3个角时,通过仿射变换估算第4个角位置
|
||||
|
||||
公式:P_missing = M_inv @ [x_target, y_target, 1]
|
||||
|
||||
3.4 图像缩放
|
||||
默认缩放到0.5倍进行检测(由config控制)
|
||||
|
||||
坐标还原时乘以inv_scale,保持与标定矩阵一致
|
||||
|
||||
4. 数据流示例
|
||||
python
|
||||
输入:
|
||||
- img_rgb: H×W×3 图像
|
||||
- laser_xy: (x_px, y_px) 激光点像素坐标
|
||||
- marker_positions: {0:[0,0,0], 1:[0,30,0], 2:[30,30,0], 3:[30,0,0]} # 4角3D坐标(cm)
|
||||
|
||||
输出:
|
||||
{
|
||||
"ok": True,
|
||||
"dx_cm": 2.5, # 靶面X偏移(cm,向右为正)
|
||||
"dy_cm": -3.2, # 靶面Y偏移(cm,向上为正)
|
||||
"distance_m": 5.43, # 相机到靶面距离(米)
|
||||
"offset_method": "triangle_homography",
|
||||
"distance_method": "pnp_triangle"
|
||||
}
|
||||
5. 鲁棒性设计
|
||||
5.1 参数自适应
|
||||
从config.py动态读取所有阈值(可在线调整)
|
||||
|
||||
三角形边长范围、灰度阈值、对比度要求等均可配置
|
||||
|
||||
5.2 异常处理
|
||||
角点退化检测(距离<3像素判定为重复)
|
||||
|
||||
NaN/Inf校验(单应性矩阵、偏移量、距离)
|
||||
|
||||
距离合理性检查(0.3~20米)
|
||||
|
||||
5.3 降级策略
|
||||
PnP失败 → 只输出偏移,距离置None
|
||||
|
||||
4角检测失败 → 尝试3角补全
|
||||
|
||||
快速路径失败 → CLAHE增强兜底(可选)
|
||||
|
||||
6. 性能特点
|
||||
CPU友好:默认Otsu单次处理,多数场景10-30ms完成检测
|
||||
|
||||
内存可控:最大候选数截断(默认10个),避免组合爆炸
|
||||
|
||||
嵌入式适配:支持图像缩放、早退机制降低计算量
|
||||
|
||||
7. 局限性
|
||||
依赖四个等腰直角三角形(需靶纸特殊设计)
|
||||
|
||||
要求三角形内部足够暗、与背景有对比度
|
||||
|
||||
单应性假设靶面为平面(实际靶纸可能有轻微起伏)
|
||||
|
||||
这套算法在射击训练系统中作为主要定位手段。
|
||||
|
||||
8. 为了加速单应性的计算,引入了yolo模型,一共做了两个模型,一个为靶纸和黑色三角形一体的识别模型,用于做原照片上快速找到靶纸区域。另一个模型是黑色三角形的模型,用于做靶纸区域再找黑色三角形。但是经过对比发现,引入黑色三角形模型反而更慢。入下面的流程A和流程B:
|
||||
yolo靶纸+传统(流程B) yolo靶纸+yolo黑色三角形(流程A)
|
||||
平均值 646.08 916.4457143
|
||||
标准差 94.61300968 57.40401849
|
||||
|
||||
公共前置(两条路都一样)
|
||||
是否用靶环模型裁 Stage1
|
||||
|
||||
TRIANGLE_YOLO_ROI_ENABLE=True 时:跑 靶环 YOLO,得到全图上的 roi_xyxy,后面的三角形都在 img_work = 全图[roi] 上做(必要时再缩成 img_det 给整图传统分支用)。
|
||||
False 时:roi_xyxy=None,三角形在 整幅相机图 上当 img_work。
|
||||
之后都进入 try_triangle_scoring(img_cv, …, roi_xyxy=…, black_yolo_boxes_work=…)
|
||||
|
||||
在里面先做灰度、v_suppress、锐化、det_scale 缩略图等 prep(与是否黑三角模型无关)。
|
||||
差别从 black_yolo_boxes_work 有没有有效子框列表 开始。
|
||||
|
||||
流程 A:用黑色三角形模型(Stage2 黑三角 YOLO)
|
||||
配置要点:TRIANGLE_BLACK_YOLO_ENABLE=True,且 TRIANGLE_BLACK_TRIANGLE_LOCATE_MODE="yolo",并且 已有 Stage1 裁切(roi_xyxy 不能为 None,否则根本不会跑黑三角 YOLO)。
|
||||
|
||||
步骤概要:
|
||||
|
||||
try_black_triangle_boxes_work
|
||||
|
||||
输入:全图 RGB + Stage1 的 ring_roi_xyxy。
|
||||
在 Stage1 裁切图(与训练一致的 slab)上跑 黑三角 YOLO,得到若干个 子框(black_boxes_work,坐标在 裁切图/work 系)。
|
||||
try_triangle_scoring 内
|
||||
|
||||
若 black_yolo_boxes_work 非空:
|
||||
按配置在 Stage1 全分辨率灰度(或缩略灰度,视 det_scale / TRIANGLE_BLACK_YOLO_PATCH_GRAY_SOURCE)上,对每个子框裁 patch,跑 _extract_triangle_from_yolo_patch(子框内:Otsu → 失败再单次 Adaptive + 轮廓 + 形状/颜色)。
|
||||
median_leg 过滤,再 四点分配 ID。
|
||||
若 ≥3 个(通常 4 个)有效:认为 Stage2 成功,跳过 整幅 Stage1 上的 detect_triangle_markers。
|
||||
若 不足 3 个 且未关 fallback:在 缩略后的整幅 work 灰度上再走 detect_triangle_markers(整图 Otsu + 整图 Adaptive×block_sizes + 各类 fallback),与「不用黑三角模型时的传统主路径」同类。
|
||||
后续
|
||||
|
||||
角点从 det 坐标 ×inv_scale 回到 work,再 +roi 原点 回到全图;单应性、补第 4 点、PnP 等与另一条路相同。
|
||||
耗时上多出来的部分:黑三角 YOLO 推理 + 每个子框一遍传统小流水线(成功时通常 不再付整图 detect_triangle_markers)。
|
||||
|
||||
流程 B:不用黑色三角形模型(纯传统定位三角)
|
||||
典型配置(任一即可达到「不用黑三角模型」的效果):
|
||||
|
||||
TRIANGLE_BLACK_YOLO_ENABLE=False,或
|
||||
TRIANGLE_BLACK_TRIANGLE_LOCATE_MODE="traditional"(即使模型开关开着也不跑黑三角 YOLO),或
|
||||
没有 Stage1 ROI(roi_xyxy is None)时,当前逻辑下 也不会跑 Stage2 黑三角 YOLO。
|
||||
此时 black_yolo_boxes_work=None(或不等价于「有子框」)。
|
||||
|
||||
步骤概要:
|
||||
|
||||
try_triangle_scoring 内
|
||||
不跑 子框 _extract_triangle_from_yolo_patch。
|
||||
直接在 img_det(缩略后的 work) 上调用 detect_triangle_markers:
|
||||
全局 Otsu(若 TRIANGLE_SKIP_GLOBAL_OTSU_EXTRACT_ON_YOLO_ROI 在有 ROI 时可能 不算 Otsu 轮廓,但仍会生成 Otsu 图供后续用);
|
||||
可选 象限 ROI(TRIANGLE_ROI_ENABLED);
|
||||
整图 Adaptive(TRIANGLE_ADAPTIVE_BLOCK_SIZES,例如 (11,));
|
||||
不足再走 放宽 approxPolyDP、BlackHat 等。
|
||||
后面同样是过滤、四点组合/象限分配、单应性、PnP 等。
|
||||
特点:没有黑三角 NPU 时间,也 没有「按框重复 4 次子框传统」;但要在 一整张(缩略)ROI 图 上跑一套更重的 整图 pipeline。
|
||||
|
||||
对照一句话
|
||||
用黑三角 YOLO(流程 A) 不用黑三角 YOLO(流程 B)
|
||||
Stage2
|
||||
黑三角模型给子框 → 子框内 Otsu + 至多一次 Adaptive
|
||||
无 Stage2 模型
|
||||
三角角点从哪来
|
||||
优先 子框传统;不够再 整图 detect_triangle_markers
|
||||
只有 整图 detect_triangle_markers
|
||||
和「全图是否只做 Adaptive」
|
||||
子框 不是只做 Adaptive;整图回退时也与全图路径一致(先 Otsu 等)
|
||||
整图路径 也不是只做 Adaptive
|
||||
靶环 YOLO(Stage1 裁切)在 A/B 里都可以开或关,与「黑三角模型」是独立开关。
|
||||
|
||||
|
||||
102
design_doc/command_record.md
Normal file
102
design_doc/command_record.md
Normal file
@@ -0,0 +1,102 @@
|
||||
|
||||
1. CPP构建命令:
|
||||
|
||||
cd /mnt/d/code/archery/cpp_ext
|
||||
rm -rf build && mkdir build && cd build
|
||||
|
||||
TOOLCHAIN_BIN=/mnt/d/code/MaixCDK/dl/extracted/toolchains/maixcam/host-tools/gcc/riscv64-linux-musl-x86_64/bin
|
||||
PYDEV=/mnt/d/code/shooting/python3_lib_maixcam_musl_3.11.6
|
||||
MAIXCDK=/mnt/d/code/MaixCDK
|
||||
|
||||
cmake .. -G Ninja \
|
||||
-DCMAKE_C_COMPILER="${TOOLCHAIN_BIN}/riscv64-unknown-linux-musl-gcc" \
|
||||
-DCMAKE_CXX_COMPILER="${TOOLCHAIN_BIN}/riscv64-unknown-linux-musl-g++" \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_C_FLAGS="-mcpu=c906fdv -march=rv64imafdcv0p7xthead -mcmodel=medany -mabi=lp64d" \
|
||||
-DCMAKE_CXX_FLAGS="-mcpu=c906fdv -march=rv64imafdcv0p7xthead -mcmodel=medany -mabi=lp64d" \
|
||||
-DPY_INCLUDE_DIR="${PYDEV}/include/python3.11" \
|
||||
-DPY_LIB="${PYDEV}/lib/libpython3.11.so" \
|
||||
-DPY_EXT_SUFFIX=".cpython-311-riscv64-linux-gnu.so" \
|
||||
-DMAIXCDK_PATH="${MAIXCDK}"
|
||||
|
||||
ninja
|
||||
|
||||
|
||||
2. Maixvision 直接跑项目的时候,是复制到板子上的这个目录:/tmp/maixpy_run
|
||||
|
||||
3. 4g 模块的终端测试方法:
|
||||
3.1 一个窗口 ssh 到maixcam的板子上之后,通过 printf 输入命令到 /dev/ttyS2, 然后另外一个窗口通过 cat /dev/ttyS2 输出
|
||||
# 1. 确保 PDP 激活
|
||||
printf 'AT+CGPADDR=1\r\n' > /dev/ttyS2
|
||||
# 2. 开启日志监听(另一个 SSH 窗口)
|
||||
cat /dev/ttyS2
|
||||
# 3. 发送下载命令(原窗口)
|
||||
printf 'AT+MHTTPDLFILE="http://static.shelingxingqiu.com/shoot/v1/main.py","downloaded.py",5120\r\n' > /dev/ttyS2
|
||||
|
||||
4. wifi的启动条件,在 /boot 目录下,看看是否有 wifi.sta 和 wifi.ssid, wifi.pass 这些文件。其中 wifi.sta 是开关文件。
|
||||
如果没有了它就不会启动wifi流程。具体的wifi流程 由 /etc/init.d/S30wifi 控制。它会判断 wifi.sta 是否存在,然后是否启动wifi,还是启动热点。
|
||||
|
||||
5. 给自己的程序打包到基础镜像中,参考:https://wiki.sipeed.com/maixpy/doc/zh/pro/compile_os.html
|
||||
5.1. 按照链接中的步骤,去github上获取了基础镜像,这次使用的是 v4.12.4,把Assets中的下面几样东西下载下来,我是在windows的wsl中执行的,注意,
|
||||
假如是在windows中下载的文件,在wsl中编译会很慢,所以我采用的是直接在wsl中下载,放到wsl的自己的文件系统中。
|
||||
1)maixcam-2025-12-31-maixpy-v4.12.4.img.xz
|
||||
2)maixcam_builtin_files.tar.xz
|
||||
3)MaixPy-4.12.4-py3-none-any.whl
|
||||
4)Source code(zip)
|
||||
5.2. 把自己的文件放到 buildtin_files中:
|
||||
1)我把项目文件目录 t11 放到了 maixcam_builtin_files\maixapp\apps 这个目录下。
|
||||
2)为了能让它自启动,我把 auto_start.txt 放到了 maixcam_builtin_files\maixapp 这个目录下。
|
||||
|
||||
5.3. 然后在解压后的源码中找到tools/os目录下 /home/saga/maixcam/MaixPy-4.12.4/tools/os/maixcam
|
||||
执行
|
||||
export MAIXCDK_PATH=/home/saga/maixcam/MaixCDK
|
||||
编译:
|
||||
./gen_os.sh ../../../../../maixcam/maixcam-2025-12-31-maixpy-v4.12.4.img ../../../../../maixcam/MaixPy-4.12.4-py3-none-any.whl ../../../../../maixcam/maixcam_builtin_files 0 maixcam
|
||||
注意,在编译过程中,也会去 github 下载内容,所以需要打开梯子。
|
||||
5.4. 等待编译完成,会编译成镜像文件,然后根据 https://wiki.sipeed.com/hardware/zh/maixcam/os.html 这个指引来烧录系统。
|
||||
5.5. 烧录完系统后,需要安装 runtime, 可以按照 https://wiki.sipeed.com/maixpy/doc/zh/README_no_screen.html 这个来升级运行库,或者直接在 Maixvision 中链接的时候安装 runtime。
|
||||
5.6. 安装 runtime 之后,重启,我们的系统就会自己启动起来了。
|
||||
|
||||
遇到问题:
|
||||
/mnt/d/code/shooting/compile_maixcam/MaixPy-4.12.4/MaixPy-4.12.4/tools/os/maixcam/fuse2fs: error while loading shared libraries: libfuse.so.2: cannot open shared object file: No such file or directory
|
||||
解决办法:
|
||||
安装 libfuse2
|
||||
sudo apt update
|
||||
sudo apt install libfuse2
|
||||
|
||||
遇到问题:
|
||||
python 缺少 yaml
|
||||
解决办法:
|
||||
pip install pyyaml
|
||||
|
||||
遇到问题:
|
||||
./build_all.sh: line 56: maixtool: command not found
|
||||
解决办法:
|
||||
pip install maixtool
|
||||
|
||||
遇到问题:
|
||||
./update_img.sh: line 80: mcopy: command not found
|
||||
解决办法:
|
||||
sudo apt update
|
||||
sudo apt install mtools
|
||||
|
||||
6. 相机标定:
|
||||
然后在板子上跑 test 目录下的 test_camera_rtsp.py ,让相机启动了一个服务,然后在电脑上接收这个视频流,并且跑opencv 内置的标定程序:
|
||||
set OPENCV_FFMPEG_CAPTURE_OPTIONS="rtsp_transport;tcp"
|
||||
opencv_interactive-calibration -t=chessboard -w=9 -h=6 -sz=0.025 -v="http://192.168.1.81:8000/stream" 2>nul
|
||||
|
||||
|
||||
7. 生成训练图片:在test目录下,执行以下命令。注意,其中 D:\code\shooting\target_photo\write.png 是靶纸的图片。
|
||||
D:\data\test_target_photo 是用来叠加的背景图
|
||||
|
||||
7.1 生成靶纸及黑色三角形的截图的图片,带动动,但1.12的外框
|
||||
bak
|
||||
python .\synth_compose_yolo.py --perspective 0.04 --perspective-prob 0.8 --color-jitter 0.6 --bg-dir D:\data\test_target_photo --fg D:\code\shooting\target_photo\write.png --out ./synth_out --class-name triangle --zip ./maix_dataset.zip --num 60 --triangles-json archery_triangles_default.json --format voc --stage2-crop --stage2-pad-min 0.03 --stage2-pad-max 0.18 --motion-prob 0.9 --motion-kernel-max 8 --blur-max 0 --triangle-bbox-pad-frac 0.12
|
||||
|
||||
bak_2
|
||||
python synth_keypoints_right_angle.py --bg-dir D:\data\test_target_photo --fg D:\code\shooting\target_photo\write.png --triangles-json archery_triangles_default.json --out ./synth_out --num 1000 --offscreen-shift-prob 0.3 --offscreen-shift-frac 0.4 --offscreen-min-visible 1 --stage2-crop --stage2-pad-min 0.03 --stage2-pad-max 0.18 --motion-prob 0.9 --motion-kernel-max 8 --blur-max 0 --perspective-mode planar --yaw-max-deg 10 --pitch-max-deg 8 --roll-max-deg 4 --planar-focal-frac 1.45 --perspective-prob 0.4
|
||||
|
||||
python synth_keypoints_right_angle.py --bg-dir D:\data\test_target_photo --fg D:\code\shooting\target_photo\write.png --triangles-json archery_triangles_default.json --out ./synth_out --num 1000 --offscreen-shift-prob 0.3 --offscreen-shift-frac 0.4 --offscreen-min-visible 1 --stage2-crop --stage2-pad-min 0.03 --stage2-pad-max 0.18 --motion-prob 1.0 --motion-kernel-max 8 --blur-max 0 --perspective-mode planar --yaw-max-deg 10 --pitch-max-deg 8 --roll-max-deg 4 --planar-focal-frac 1.45 --perspective-prob 0.4
|
||||
|
||||
|
||||
python pose_pixel_metrics.py --model D:\code\archery\runs\pose\runs\pose\target_pose_train\weights\best.pt --data D:\code\archery\datasets\dataset_pose.yaml --imgsz 640
|
||||
41
design_doc/debug.md
Normal file
41
design_doc/debug.md
Normal file
@@ -0,0 +1,41 @@
|
||||
1. 问题描述:开机失败,一直遇到Traceback (most recent call last):
|
||||
File "/tmp/maixpy_run/main.py", line 525, in <module>
|
||||
cmd_str()
|
||||
File "/tmp/maixpy_run/main.py", line 102, in cmd_str
|
||||
camera_manager.init_camera(640, 480)
|
||||
File "/tmp/maixpy_run/camera_manager.py", line 59, in init_camera
|
||||
self._camera = camera.Camera(width, height)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
RuntimeError: : Runtime error: mmf vi init failed
|
||||
解决方案:
|
||||
根据过往经验,极有可能是摄像头的接线有问题。因为在测试环境,摄像头是通过一个24针转22针的线出来的,然后再通过一个接线中继,连接到一个22针
|
||||
的fpc线到Maixcam。接线中继如果是24针的,多了两针,需要选好一边然后对连。但这里很容易出错或者松动。可以先用摄像头本身的金色接线直接接到
|
||||
Maixcam,然后跑test目录下的test_cammera.py,看看能不能正常启动,如果正常,就确定是中继接线的问题。
|
||||
|
||||
2. 问题描述:202609 批次的拓展版,在连接 202601 批次的电源板,或者不链接电源板的时候,开机后不久,出错,程序退出,日志是:
|
||||
[v1.2.10] [INFO] network.py:1078 - [NET] TCP主线程启动
|
||||
[v1.2.10] [INFO] network.py:406 - [NET] WiFi不可用或无法连接服务器,使用4G网络
|
||||
[v1.2.10] [INFO] network.py:475 - 连接到服务器,使用4G...
|
||||
[v1.2.10] [INFO] network.py:527 - [4G-TCP] AT+MIPCLOSE=2 response:
|
||||
OK
|
||||
+MIPCLOSE: 2
|
||||
-- [E] read failed
|
||||
Trigger signal, code:SIGSEGV(11)!
|
||||
maix multi-media driver released.
|
||||
ISP Vipipe(0) Free pa(0x8a52c000) va(0x0x3fbeb5e000)
|
||||
program exit failed. exit code: 1.
|
||||
|
||||
解决方案:
|
||||
从日志看,就是开始发送登录信息之后就崩溃了。出发了底层的read failed。经过排查,是一定要插上电源板的数据连线,以及电源板要插上电池。这个应该是登录时需要读电源电压数据。后面我们已经优化了日志,而且增加了对ina226的试探,但发现ina226不存在的时候,就直接返回电压和电流为0.0。而且,一定要注意,在新配套的电源板和核心板上面,才能正常读到电流和电压。
|
||||
|
||||
3. a)问题描述:202609 批次的拓展版,有一块maixcam的蓝灯常亮,询问maixcam的人,他们觉得应该是卡没有插好。但是拓展版上的激光口挡住了数据卡的出口,
|
||||
没法拔出检查,
|
||||
解决方案:需要做拓展版的公司(深链鑫创)在做好板子之后,确定系统能正常启动
|
||||
|
||||
b)问题描述:2022609 批次的拓展板,有一次maixcam的蓝灯亮的时候很长,不会闪烁,后面把sd卡插进去一点,又恢复正常了,初步怀疑是射箭时没有缓冲,
|
||||
导致了sd 卡被撞松了
|
||||
|
||||
4. 问题描述:4G模块不可用,模块的绿灯没有闪亮
|
||||
解决方案:有这样的一种情况,就是4G模块的天线,触碰到了旁边的电容,导致短路,所以模块启动失败。需要保证电容和天线的金属头不会触碰
|
||||
5.
|
||||
|
||||
276
design_doc/solution_record.md
Normal file
276
design_doc/solution_record.md
Normal file
@@ -0,0 +1,276 @@
|
||||
1. 4G OTA 下载的时候,为什么使用十六进制下载,读取 URC 事件?
|
||||
因为使用二进制下载的时候,经常会出现错误,并且会失败?然后最稳定传输的办法,是每次传输的时候,是分块,而且每次分块都要“删/建”http实例。推测原因是因为我们现在是直接传输文件的源代码,代码中含有了一些字符串可能和 AT指令重复,导致了 AT 模块在解释的时候出错。而使用 16 进制的方式,可以避免这个问题。因为十六进制直接把数据先转成了字符串,然后在设备端再把字符串转成数据,这样就不可能出现 AT的指令,从而减少了麻烦。
|
||||
2. 4G OTA 下载的时候,为什么不用 AT 模块里 HTTPDLFILE 的指令?
|
||||
因为在测试中发现,使用 HTTPDLFILE,其实是下载到了 4G 模块内部,需要重新从模块内部转到存储卡,而且 4G 模块的存储较小,大概只有 40k,所以还需要分块来下载和转存,比较麻烦,于是最终使用了使用读取串口事件的模式。
|
||||
3. 4G OTA 下载的时候,为什么不用 AT 模块里 HTTPREAD 的指令?
|
||||
因为之前测试发现,READ模式其实是需要多步:
|
||||
3.1. AT+MHTTPCREATE
|
||||
3.2. AT+MHTTPCFG
|
||||
3.3. AT+MHTTPREQUEST
|
||||
3.4. AT+MHTTPREAD
|
||||
它其实也是把数据下载到 4g 模块的缓存里,然后再从缓存里读取出来。所以也是比较繁琐的,还不如 HTTPDLFILE 简单。
|
||||
4. WiFi OTA 流程(ota_manager.handle_wifi_and_update())
|
||||
* 解析 ota_url 得到 host:port
|
||||
* 调用 network_manager.connect_wifi(ssid, password, verify_host=host, verify_port=port, persist=True)
|
||||
* 只有“能连上 WiFi 且能访问 OTA host:port”才会把新凭证保留在 /boot
|
||||
* 连接成功后开始下载 OTA 文件(download_file())
|
||||
* 下载成功则 apply_ota_and_reboot()
|
||||
5. TCP 通信
|
||||
1) 平时 TCP 通信主流程(network_manager.tcp_main())
|
||||
外层无限循环:一直尝试保持与服务器的 TCP 会话。
|
||||
每轮开始:
|
||||
如果 OTA 正在进行:暂停(避免抢占资源/串口)。
|
||||
connect_server():建立 TCP 连接(自动选 WiFi 或 4G)。
|
||||
发送“登录包”(msg_type=1),等待服务器返回“登录成功”。
|
||||
登录成功后进入内层循环:
|
||||
接收数据:
|
||||
WiFi:非阻塞 recv();没数据返回 b"";有数据进入缓冲区拼包解析。
|
||||
4G:从 ATClient 的队列 pop_tcp_payload() 取数据。
|
||||
处理命令/ACK:
|
||||
登录响应、心跳 ACK、OTA 命令、关机命令、日志上传命令等。
|
||||
发送业务队列:
|
||||
从高优/普通队列取 1 条,发送失败会放回队首,并断线重连(不再丢消息)。
|
||||
发送心跳:
|
||||
按 HEARTBEAT_INTERVAL 发心跳包。
|
||||
心跳失败会计数(当前为连续失败到阈值才重连)。
|
||||
任何发送/接收致命失败:
|
||||
关闭 socket/断开连接 → 跳出内层循环 → 外层等待一会儿后重新 connect_server() → 重新登录。
|
||||
6. “WiFi 连接/验证”
|
||||
TCP 连接建立与网络选择(connect_server() / select_network())
|
||||
* select_network():WiFi 优先,但要求:
|
||||
is_wifi_connected() 为 True(系统层面有 WiFi IP 或 Maix WLAN connected)
|
||||
且能连到 TCP 服务器 SERVER_IP:SERVER_PORT
|
||||
否则回退到 4G
|
||||
* connect_server():
|
||||
若已有连接:WiFi 会做 _check_wifi_connection() 轻量检查;4G 直接认为 OK(由 AT 层维护)。
|
||||
否则按网络类型走:
|
||||
WiFi:创建 socket → connect → setblocking(False)(接收用非阻塞)
|
||||
4G:AT+MIPOPEN 建链
|
||||
WiFi 链接(connect_wifi())
|
||||
当前 connect_wifi() 的关键特点是:必须让 /etc/init.d/S30wifi restart 真正用新 SSID 去连,所以会临时写 /boot/wifi.ssid 和 /boot/wifi.pass,失败自动回滚。
|
||||
流程是:
|
||||
(1) 备份旧配置
|
||||
* /boot/wifi.ssid、/boot/wifi.pass
|
||||
* /etc/wpa_supplicant.conf(尽量备份)
|
||||
(2) 写入新凭证
|
||||
* 把新 ssid/pass 写到 /boot/*
|
||||
-(同时尽量写 /etc/wpa_supplicant.conf,但不强依赖)
|
||||
(3) 重启 WiFi 服务:/etc/init.d/S30wifi restart
|
||||
(4) 等待获取 IP(默认 20 秒,可调)
|
||||
(5) 验证可用性,连到 verify_host:verify_port
|
||||
(6) 成功
|
||||
* persist=True:保留 /boot/*(持久化)
|
||||
* persist=False:回滚 /boot/* 到旧值(不重启,当前连接仍可继续)
|
||||
(7) 失败
|
||||
* 回滚 /boot/* + 回滚 /etc/wpa_supplicant.conf(如果有备份)
|
||||
* 再 S30wifi restart 恢复旧网络
|
||||
* 返回错误
|
||||
|
||||
7. 日志上传(inner_cmd == 43),当前只支持 wifi 上传日志
|
||||
命令带 ssid/password/url 时:
|
||||
* 若 WiFi 未连接:先 connect_wifi(..., verify_host=upload_host, verify_port=upload_port, persist=True)
|
||||
上传内容:
|
||||
* sync # 把日志从内存同步到文件
|
||||
* 快照 app.log* 到 /tmp staging
|
||||
* 打包成 tar.gz(默认)或 zip
|
||||
* 以 multipart/form-data 的 file 字段 POST 到 url
|
||||
|
||||
8. 自动关机:
|
||||
hardware中设定了开停表,然后再增加了获取idle的时间。
|
||||
自动关机的时机: 超过配置的idle时长,
|
||||
禁止自动关机的情况:1.校准中,2.OTA中
|
||||
重启计时的时机:1.校准完成,2.命令触发射箭,3.真实触发射箭,4.初始化完成
|
||||
9. Wifi网络监控:
|
||||
有两次发现wifi网络下,有些消息发送很慢,但具体是什么缘故还不清楚,现在增加了wifi网络下的检测,并一旦发现wifi的网络质量差,就会切换到4G。
|
||||
WiFi 连接成功
|
||||
↓
|
||||
启动后台监测线程
|
||||
↓
|
||||
每 5 秒循环:
|
||||
测量 RTT (1 样本,600ms timeout)
|
||||
获取 RSSI
|
||||
更新缓存
|
||||
判断是否差:
|
||||
- RTT >= 600ms → 差
|
||||
- RTT >= 350ms 且 RSSI <= -80dBm → 差
|
||||
↓
|
||||
如果质量差:
|
||||
快速重试2次,如果其中任意一次网络恢复了,继续使用wifi。否则,
|
||||
调用 _switch_to_4g_due_to_poor_wifi()
|
||||
关闭 WiFi socket
|
||||
重置连接状态
|
||||
尝试切换到 4G
|
||||
↓
|
||||
上层检测到连接断开:
|
||||
重新 connect_server() → 自动选择 4G
|
||||
|
||||
10. 现在使用的相机,其实是支持更大的分辨率的,比如说1920*1280,但是由于我们的图像处理,拍照处理之后很容易触发OOM。
|
||||
|
||||
11. 环数计算流程:
|
||||
现在设备侧的目标是:算出箭点相对靶心的偏移(dx,dy),单位是物理厘米(cm),然后把它作为 x,y 上报给后端;后端再去算环。
|
||||
设备侧本身不直接算环数,它算的是偏移与距离,并上报。
|
||||
|
||||
算法流程(一次射箭从触发到上报)
|
||||
1) 触发后取一帧图
|
||||
在 process_shot() 里读取相机帧并调用 analyze_shot(frame)
|
||||
2) 确定激光点(laser_point)
|
||||
|
||||
analyze_shot() 第一步先确定激光点 (x,y)(像素坐标):
|
||||
|
||||
硬编码:config.HARDCODE_LASER_POINT=True → 用 laser_manager.laser_point
|
||||
已校准:laser_manager.has_calibrated_point() → 用校准值
|
||||
动态模式:先 detect_circle_v3(frame, None) 粗估距离,再根据距离反推激光点
|
||||
代码在:
|
||||
|
||||
if config.HARDCODE_LASER_POINT:
|
||||
...
|
||||
elif laser_manager.has_calibrated_point():
|
||||
...
|
||||
else:
|
||||
_, _, _, _, best_radius1_temp, _ = detect_circle_v3(frame, None)
|
||||
distance_m_first = estimate_distance(best_radius1_temp) ...
|
||||
laser_point = laser_manager.calculate_laser_point_from_distance(distance_m_first)
|
||||
3) 优先走三角形路径(成功就直接用于上报 x/y)
|
||||
如果 config.USE_TRIANGLE_OFFSET=True,先尝试识别靶面四角三角形标记:
|
||||
|
||||
if getattr(config, "USE_TRIANGLE_OFFSET", False):
|
||||
K, dist_coef, pos = _get_triangle_calib()
|
||||
img_rgb = image.image2cv(frame, False, False)
|
||||
tri = try_triangle_scoring(img_rgb, (x, y), pos, K, dist_coef, ...)
|
||||
if tri.get("ok"):
|
||||
return {... "dx": tri["dx_cm"], "dy": tri["dy_cm"], "distance_m": tri.get("distance_m"), ...}
|
||||
这一步里 try_triangle_scoring() 做了两件事(都在 triangle_target.py):
|
||||
|
||||
单应性(homography):把激光点从图像坐标映射到靶面坐标系,得到(dx,dy)(cm)
|
||||
PnP:用识别到的角点与相机标定,估算 相机到靶的距离 distance_m
|
||||
关键代码:
|
||||
|
||||
ok_h, tx, ty, _H = homography_calibration(...)
|
||||
out["dx_cm"] = tx
|
||||
out["dy_cm"] = -ty
|
||||
out["distance_m"] = dist_m
|
||||
out["distance_method"] = "pnp_triangle"
|
||||
注意:这里 dy_cm 取了负号,是为了和现网约定一致(laser_manager.compute_laser_position 的坐标方向)。
|
||||
|
||||
4) 三角形失败 → 回退圆形/椭圆靶心检测(兜底)
|
||||
如果三角形不可用或识别失败,就走传统靶心检测:
|
||||
|
||||
detect_circle_v3(frame, laser_point) 找黄心/红心、半径、椭圆参数
|
||||
用 laser_manager.compute_laser_position() 把像素偏移换算成厘米偏移(dx,dy)
|
||||
在 shoot_manager.py:
|
||||
|
||||
result_img, center, radius, method, best_radius1, ellipse_params = detect_circle_v3(frame, laser_point)
|
||||
if center and radius:
|
||||
dx, dy = laser_manager.compute_laser_position(center, (x, y), radius, method)
|
||||
distance_m = estimate_distance(best_radius1) ...
|
||||
在 laser_manager.compute_laser_position()(核心换算逻辑):
|
||||
|
||||
r = radius * 5
|
||||
target_x = (lx-cx)/r*100
|
||||
target_y = (ly-cy)/r*100
|
||||
return (target_x, -target_y)
|
||||
这里 (像素差)/(radius*5)*100 是你们旧约定下的“像素→厘米”比例模型(并且 y 方向同样取负号)。
|
||||
|
||||
5) 上报数据:把(dx,dy) 作为 x/y 发给后端
|
||||
最终上报发生在 process_shot(),直接把 dx,dy 填到 inner_data["x"],["y"]:
|
||||
|
||||
srv_x = round(float(dx), 4) if dx is not None else 200.0
|
||||
srv_y = round(float(dy), 4) if dy is not None else 200.0
|
||||
inner_data = {
|
||||
"x": srv_x,
|
||||
"y": srv_y,
|
||||
"d": round((distance_m or 0.0) * 100),
|
||||
"m": method if method else "no_target",
|
||||
"offset_method": offset_method,
|
||||
"distance_method": distance_method,
|
||||
...
|
||||
}
|
||||
network_manager.safe_enqueue(...)
|
||||
x,y:物理厘米(cm)
|
||||
d:相机到靶距离(m→cm,乘 100;三角形成功时来自 PnP)
|
||||
m/offset_method/distance_method:标记本次用的算法路径(triangle / yellow / pnp 等)
|
||||
后端收到 x,y 后,再用你之前给的 Go 公式 CalculateRingNumber(x,y,tenRingRadius) 计算环数。
|
||||
|
||||
你现在的“环数计算”实际依赖关系
|
||||
最好路径(快+稳):三角形 → dx,dy(单应性) + distance_m(PnP)
|
||||
兜底路径:圆/椭圆靶心 → dx,dy(基于黄心半径比例/透视校正) + distance_m(黄心半径估距)
|
||||
|
||||
12. 4g模块上传文件:
|
||||
|
||||
Upload images from MaixCam to Qiniu cloud via ML307R 4G module's AT commands. The HTTP body requires multipart/form-data with real CR/LF bytes (0x0D 0x0A) in boundaries.
|
||||
Methods Tried
|
||||
# Method AT Commands Result Root Cause
|
||||
1 Raw binary, no encoding MHTTPCONTENT with raw bytes + length param ERROR at first chunk CR/LF in binary data terminates AT command parser
|
||||
2 Encoding mode 2 (escape) MHTTPCFG="encoding",0,2 + \r\n escapes Server 400 Bad Request Module sends literal text \r\n to server, NOT actual 0x0D 0x0A bytes. Multipart body is garbled
|
||||
3 Encoding mode 1 (hex) MHTTPCFG="encoding",0,1 + hex-encoded data CME ERROR: 650/50 Firmware doesn't properly support hex mode for MHTTPCONTENT
|
||||
4 No chunked mode Skip MHTTPCFG="chunked" CME ERROR: 65 Module requires chunked mode to accept MHTTPCONTENT at all
|
||||
5 Single large MHTTPCONTENT All data in one command (2793 bytes) +MHTTPURC: "err",0,5 (timeout) Possible buffer limit; module hangs then times out
|
||||
6 Per-chunk HTTP instance (OTA style) CREATE→POST→DELETE per chunk Not feasible Each instance = separate HTTP request; Qiniu needs complete body in single POST
|
||||
Conclusion: AT HTTP layer (MHTTPCONTENT) is fundamentally broken for binary uploads.
|
||||
The Solution: Raw TCP Socket (MIPOPEN + MIPSEND)
|
||||
Bypass the AT HTTP layer entirely. Open a raw TCP connection and send a hand-crafted HTTP POST:
|
||||
plaintext
|
||||
AT+MIPCLOSE=3 // Clean up old socket
|
||||
AT+MIPOPEN=3,"TCP","upload.qiniup.com",80 // Raw TCP connection
|
||||
AT+MIPSEND=3,1024 → ">" → [raw bytes] → OK // Binary-safe!
|
||||
AT+MIPSEND=3,1024 → ">" → [raw bytes] → OK
|
||||
AT+MIPSEND=3,766 → ">" → [raw bytes] → OK
|
||||
// Response: +MIPURC: "rtcp",3,<len>,HTTP/1.1 200 OK...
|
||||
AT+MIPCLOSE=3
|
||||
Why it works:
|
||||
MIPSEND enters prompt mode (>) — after the >, the AT parser treats ALL bytes as data, including CR/LF
|
||||
We construct the complete HTTP request ourselves (headers + Content-Length + multipart body) with real CRLF bytes
|
||||
|
||||
Key bug found during integration: _send_chunk() wrapped calls in self.at._cmd_lock, but self.at.send() also acquires the same lock internally — threading.Lock() is not reentrant, causing deadlock. Fixed by removing the outer lock (the network_manager.get_uart_lock() already provides thread safety).Trade-off: UART is locked during the entire upload, so heartbeats pause. For small JPEG files (~2-80KB), this is 5-20 seconds — acceptable if server heartbeat timeout is generous
|
||||
|
||||
|
||||
13. 算环数算法1:「黄心 + 红心」椭圆/圆:主要在 vision.py 的 detect_circle_v3() 里完成:颜色先用 HSV 做掩码,再在轮廓上做面积、圆度筛选,黄圈用椭圆拟合,红圈预先筛成候选,最后用几何关系配对。
|
||||
|
||||
1. 黄色怎么判、范围是什么?
|
||||
图像先转 HSV(cv2.COLOR_RGB2HSV,注意输入是 RGB)。
|
||||
饱和度 S 整体乘 1.1 并限制在 0–255(让黄色更「显」一点)。
|
||||
黄色 inRange(OpenCV HSV,H 多为 0–179):
|
||||
通道 下限 上限
|
||||
H 7 32
|
||||
S 80 255
|
||||
V 0 255
|
||||
在黄掩码上找轮廓后,还要满足:面积 > 50,圆度 > 0.7(circularity = 4π·面积/周长²),且点数 ≥5 才 fitEllipse 当黄心椭圆。
|
||||
|
||||
2. 红色怎么判、范围是什么?
|
||||
红色在 HSV 里跨 0°,所以用 两段 H 做并集:
|
||||
两段分别是:
|
||||
H 0–10,S 80–255,V 0–255
|
||||
H 170–180,S 80–255,V 0–255
|
||||
红轮廓候选:面积 > 50,圆度 > 0.6(比黄略松),再拟合椭圆或最小外接圆得到圆心和半径。
|
||||
|
||||
3. 「黄心」和「红心」怎样算一对?(几何范围)
|
||||
对每个黄圈,在红色候选里找第一个满足:
|
||||
|
||||
两圆心距离 dist_centers < yellow_radius * 1.5
|
||||
红半径 red_radius > yellow_radius * 0.8(红在外圈、略大)
|
||||
dist_centers = math.hypot(ddx, ddy)
|
||||
if dist_centers < yellow_radius * 1.5 and rc["radius"] > yellow_radius * 0.8:
|
||||
小结:黄色 = HSV H∈[7,32]、S≥80(且 S 放大 1.1)+ 形态学闭运算 + 面积/圆度;红色 = 两段 H(0–10 与 170–180)、S≥80 + 闭运算 + 面积/圆度;配对用 同心/包含 的距离与半径比例阈值。若你还关心 laser_manager.py 里「激光红点」的另一套阈值(LASER_*),那是另一条链路,和靶心黄/红 HSV 可以分开看。
|
||||
|
||||
14. 算环数算法2:
|
||||
使用单应性矩阵计算:镜头中心点(照片中心像素)到虚拟平面的转换。它不需要知道相机在 3D 空间中的具体位置,直接通过单应性矩阵 H的逆运算,将 2D 像素“翻译”成虚拟平面上的 2D 坐标。
|
||||
|
||||
一、转换的本质:2D 到 2D 的“查字典”
|
||||
单应性变换(Homography)是平面到平面的映射。它不处理 3D 空间中的“投影线”,而是直接建立图像像素 (u,v) 与虚拟平面坐标 (x,y) 的一一对应关系。
|
||||
你可以把单应性矩阵 H想象成一本“翻译字典”:
|
||||
正变换 H:已知靶纸上的真实位置 (x,y),查字典得到它在照片上哪个像素 (u,v)。
|
||||
逆变换 H−1:已知照片上的像素 (u,v)(如镜头中心点),查字典反推它在靶纸上的真实位置 (x,y)。
|
||||
这个“虚拟平面”就是你的靶纸平面(Z=0 的世界坐标系)。算法没有在物理上移动任何点,只是在做坐标系的换算。
|
||||
|
||||
二、详细步骤:镜头中心点如何“落地”
|
||||
|
||||
相机分辨率是 640x480,镜头中心点(光轴与图像的交点)通常是 (u0,v0)=(320,240)。
|
||||
1. 输入:镜头中心点(像素)
|
||||
2. 核心运算:乘以逆矩阵
|
||||
通过 4 个黑色三角形的角点(已知真实坐标)计算出了单应性矩阵 H。现在使用它的逆矩阵 H−1
|
||||
3. 输出:虚拟平面上的落点(物理坐标)
|
||||
计算后,你会得到:(xhit,yhit)
|
||||
这就是镜头中心点对应的靶纸上的真实位置(单位:毫米)。
|
||||
4. 计算环数
|
||||
由于虚拟平面原点 (0,0)就是靶纸圆心,直接计算欧氏距离。
|
||||
这个 d就是箭着点偏离圆心的真实物理距离,直接用于环数判定。
|
||||
25
design_doc/todo.md
Normal file
25
design_doc/todo.md
Normal file
@@ -0,0 +1,25 @@
|
||||
你现在要防的是“别人拿到设备/拿到代码包后,能伪造请求、刷接口、下发恶意 OTA、甚至劫持通信”。单靠隐藏 Python 源码只能提高门槛,真正的安全要靠协议和密钥设计。结合你仓库里实际内容,建议你重点隐藏/整改这些点(按风险排序)。
|
||||
|
||||
|
||||
1.4 OTA 下发“url”如果缺少强校验,就是远程代码执行入口
|
||||
你 OTA 逻辑里从服务器指令拿到 url 就去下载并替换文件/重启(这是正常 OTA),但安全性取决于:
|
||||
是否只允许白名单域名/路径
|
||||
是否强制 https 并校验服务器证书
|
||||
是否对 OTA 包做签名校验(最关键)
|
||||
你这里能看到固定域名 static.shelingxingqiu.com 的特殊处理(ota_manager.py 里还在纠结 http/https),这块一定要“服务端签名 + 设备端验签”,否则隐藏源码也没用。
|
||||
|
||||
2. 建议隐藏(中风险,但很容易被人利用)
|
||||
2.1 所有服务器地址/端口/API 路径(可被用于扫描、压测、撞库、协议逆向)
|
||||
这些在 config.py 是明文:
|
||||
SERVER_IP = "stcp.shelingxingqiu.com"SERVER_PORT = 50005HTTP_URL = "http://ws.shelingxingqiu.com"HTTP_API_PATH = "/home/shoot/device_fire/arrow/fire"
|
||||
注意:即使你把它们藏进 .so,攻击者仍可能通过抓包/观察 DNS/SNI/流量拿到域名或 IP,所以“隐藏”只能降低静态分析风险,不能替代鉴权/签名。
|
||||
|
||||
|
||||
|
||||
3. 最有效的“防黑”措施(比隐藏更关键)
|
||||
|
||||
所有上报/控制命令加签名 + nonce/timestamp + 服务端防重放(别人抓到一次包也不能复用)
|
||||
OTA 包必须做签名校验(设备端内置公钥,下载后验签通过才应用)
|
||||
|
||||
如果你告诉我:你们服务端目前能不能改协议(例如新增签名字段、下发 challenge、做 OTA 签名),我可以按“最小改动但提升最大安全”的顺序,帮你规划一套从现状平滑升级的方案。
|
||||
|
||||
133
hardware.py
Normal file
133
hardware.py
Normal file
@@ -0,0 +1,133 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
硬件管理器模块
|
||||
提供硬件对象的统一管理和访问
|
||||
"""
|
||||
from maix import time
|
||||
import config
|
||||
from at_client import ATClient
|
||||
|
||||
|
||||
class HardwareManager:
|
||||
"""硬件管理器(单例)"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(HardwareManager, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 私有硬件对象
|
||||
self._uart4g = None # 4G模块UART
|
||||
self._bus = None # I2C总线
|
||||
self._adc_obj = None # ADC对象
|
||||
self._at_client = None # AT客户端
|
||||
|
||||
self._last_active_time = 0 # 用于记录用户的最后一次活跃的时间
|
||||
self._stop_timer = False # 用于停止定时器的标志
|
||||
|
||||
self._initialized = True
|
||||
|
||||
|
||||
|
||||
# ==================== 硬件访问(只读属性)====================
|
||||
|
||||
@property
|
||||
def uart4g(self):
|
||||
"""4G模块UART(只读)"""
|
||||
return self._uart4g
|
||||
|
||||
@property
|
||||
def bus(self):
|
||||
"""I2C总线(只读)"""
|
||||
return self._bus
|
||||
|
||||
@property
|
||||
def adc_obj(self):
|
||||
"""ADC对象(只读)"""
|
||||
return self._adc_obj
|
||||
|
||||
@property
|
||||
def at_client(self):
|
||||
"""AT客户端(只读)"""
|
||||
return self._at_client
|
||||
|
||||
# ==================== 初始化方法 ====================
|
||||
|
||||
def init_uart4g(self, device=None, baudrate=None):
|
||||
"""初始化4G模块UART"""
|
||||
from maix import uart
|
||||
if device is None:
|
||||
device = config.UART4G_DEVICE
|
||||
if baudrate is None:
|
||||
baudrate = config.UART4G_BAUDRATE
|
||||
self._uart4g = uart.UART(device, baudrate)
|
||||
return self._uart4g
|
||||
|
||||
def init_bus(self, bus_num=None):
|
||||
"""初始化I2C总线"""
|
||||
from maix import i2c
|
||||
if bus_num is None:
|
||||
bus_num = config.I2C_BUS_NUM
|
||||
self._bus = i2c.I2C(bus_num, i2c.Mode.MASTER)
|
||||
return self._bus
|
||||
|
||||
def init_adc(self, channel=None, res_bit=None):
|
||||
"""初始化ADC"""
|
||||
from maix.peripheral import adc
|
||||
if channel is None:
|
||||
channel = config.ADC_CHANNEL
|
||||
if res_bit is None:
|
||||
res_bit = adc.RES_BIT_12
|
||||
self._adc_obj = adc.ADC(channel, res_bit)
|
||||
return self._adc_obj
|
||||
|
||||
def init_at_client(self, uart_obj=None):
|
||||
"""初始化AT客户端"""
|
||||
if uart_obj is None:
|
||||
if self._uart4g is None:
|
||||
raise ValueError("uart4g must be initialized before at_client")
|
||||
uart_obj = self._uart4g
|
||||
self._at_client = ATClient(uart_obj)
|
||||
self._at_client.start()
|
||||
return self._at_client
|
||||
|
||||
def power_off(self):
|
||||
"""关闭电源板"""
|
||||
try:
|
||||
# 物理引脚是 A24,对应 GPIO 功能是 GPIOA24
|
||||
# 注意:这里需要先在 config.PIN_MAPPINGS 中配置好 "A24": "GPIOA24"
|
||||
from maix import gpio
|
||||
# 输出高电平关闭
|
||||
gpio.GPIO("GPIOA24", gpio.Mode.OUT).value(1)
|
||||
except Exception as e:
|
||||
print(f"关机失败: {e}")
|
||||
|
||||
def start_idle_timer(self):
|
||||
self._stop_timer = False
|
||||
self._last_active_time = time.time()
|
||||
|
||||
def stop_idle_timer(self):
|
||||
self._stop_timer = True
|
||||
|
||||
def get_idle_time_in_sec(self):
|
||||
if self._stop_timer:
|
||||
return 0
|
||||
diff = time.time() - self._last_active_time
|
||||
if diff < 0:
|
||||
# 时间可能被重置了,重新计时
|
||||
self._last_active_time = time.time()
|
||||
return 0
|
||||
return diff
|
||||
|
||||
|
||||
# 创建全局单例实例
|
||||
hardware_manager = HardwareManager()
|
||||
|
||||
|
||||
52
keygen.py
Normal file
52
keygen.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
|
||||
|
||||
def generate_key_pair():
|
||||
"""
|
||||
生成一对新的密钥a和b,使得a XOR b等于原始key
|
||||
:return: (a, b, key) 元组,每个元素都是32字节的字节数组
|
||||
"""
|
||||
# 原始key值
|
||||
key = bytes([
|
||||
0x5d, 0xf9, 0xef, 0xc4, 0x5d, 0xcc, 0xc7, 0x8d, 0xc9, 0x86, 0x34, 0x11, 0x6f, 0xb4, 0xcf, 0x75,
|
||||
0xbf, 0x24, 0x47, 0x9d, 0xd6, 0x5d, 0x83, 0x4b, 0xa6, 0xc0, 0xde, 0x27, 0x91, 0x92, 0xb1, 0x63
|
||||
])
|
||||
|
||||
# 随机生成a
|
||||
a = os.urandom(32)
|
||||
|
||||
# 计算b = key XOR a
|
||||
b = bytes([key[i] ^ a[i] for i in range(32)])
|
||||
|
||||
return a, b, key
|
||||
|
||||
|
||||
def format_hex_array(data):
|
||||
"""
|
||||
将字节数组格式化为C++风格的十六进制数组
|
||||
:param data: 字节数组
|
||||
:return: 格式化后的字符串
|
||||
"""
|
||||
return "{" + ",".join([f"0x{b:02x}" for b in data]) + "}"
|
||||
|
||||
|
||||
def generate_new_key_pair():
|
||||
"""
|
||||
生成新的密钥对并打印出来
|
||||
"""
|
||||
a, b, key = generate_key_pair()
|
||||
|
||||
print("原始key:")
|
||||
print(format_hex_array(key))
|
||||
print("\n新的密钥对:")
|
||||
print("a =", format_hex_array(a))
|
||||
print("b =", format_hex_array(b))
|
||||
|
||||
# 验证a XOR b是否等于key
|
||||
verify_key = bytes([a[i] ^ b[i] for i in range(32)])
|
||||
assert verify_key == key, "验证失败:a XOR b 不等于 key"
|
||||
print("\n验证成功:a XOR b 等于 key")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_new_key_pair()
|
||||
826
laser.py
826
laser.py
@@ -1,826 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
激光射击系统主程序(激光测距版)
|
||||
功能:目标检测、激光校准、4G TCP 通信、OTA 升级、M01 激光测距、INA226 电量监测
|
||||
平台:MaixPy (Sipeed MAIX)
|
||||
作者:ZZH
|
||||
最后更新:2025-11-21
|
||||
"""
|
||||
|
||||
from maix import camera, display, image, app, time, key, uart, pinmap, i2c, network, err
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import struct
|
||||
import re
|
||||
from maix.peripheral import adc
|
||||
import _thread
|
||||
import os
|
||||
import requests
|
||||
import socket
|
||||
import binascii
|
||||
|
||||
# ==============================
|
||||
# 全局配置
|
||||
# ==============================
|
||||
# OTA 升级地址(建议后续改为动态下发)
|
||||
url = "https://static.shelingxingqiu.com/shoot/202511031031/main.py"
|
||||
local_filename = "/maixapp/apps/t11/main.py"
|
||||
|
||||
DEVICE_ID = None
|
||||
PASSWORD = None
|
||||
SERVER_IP = "www.shelingxingqiu.com"
|
||||
SERVER_PORT = 50005
|
||||
HEARTBEAT_INTERVAL = 2 # 心跳间隔(秒)
|
||||
|
||||
CONFIG_FILE = "/root/laser_config.json"
|
||||
DEFAULT_POINT = (640, 480) # 图像中心点
|
||||
laser_point = DEFAULT_POINT
|
||||
|
||||
# HTTP API(当前未使用,保留备用)
|
||||
URL = "http://ws.shelingxingqiu.com"
|
||||
API_PATH = "/home/shoot/device_fire/arrow/fire"
|
||||
|
||||
# UART 设备初始化
|
||||
uart4g = uart.UART("/dev/ttyS2", 115200) # 4G 模块(TCP 透传)
|
||||
distance_serial = uart.UART("/dev/ttyS1", 9600) # M01 激光测距模块
|
||||
|
||||
# 消息类型常量
|
||||
MSG_TYPE_LOGIN_REQ = 1 # 登录请求
|
||||
MSG_TYPE_STATUS = 2 # 状态上报
|
||||
MSG_TYPE_HEARTBEAT = 4 # 心跳包
|
||||
# 引脚功能映射
|
||||
pinmap.set_pin_function("A18", "UART1_RX")
|
||||
pinmap.set_pin_function("A19", "UART1_TX")
|
||||
pinmap.set_pin_function("A29", "UART2_RX")
|
||||
pinmap.set_pin_function("A28", "UART2_TX")
|
||||
pinmap.set_pin_function("P18", "I2C1_SCL")
|
||||
pinmap.set_pin_function("P21", "I2C1_SDA")
|
||||
# pinmap.set_pin_function("A15", "I2C5_SCL")
|
||||
# pinmap.set_pin_function("A27", "I2C5_SDA")#ota升级要修改的
|
||||
# ADC 触发阈值(用于检测扳机/激光触发)
|
||||
ADC_TRIGGER_THRESHOLD = 3000
|
||||
ADC_LASER_THRESHOLD = 3000
|
||||
# 显示参数
|
||||
color = image.Color(255, 100, 0) # 橙色十字线
|
||||
thickness = 1
|
||||
length = 2
|
||||
|
||||
# ADC 扳机触发阈值(0~4095)
|
||||
ADC_TRIGGER_THRESHOLD = 3000
|
||||
|
||||
# I2C 电源监测(INA226)
|
||||
adc_obj = adc.ADC(0, adc.RES_BIT_12)
|
||||
bus = i2c.I2C(1, i2c.Mode.MASTER)
|
||||
# bus = i2c.I2C(5, i2c.Mode.MASTER)#ota升级总线
|
||||
INA226_ADDR = 0x40
|
||||
REG_CONFIGURATION = 0x00
|
||||
REG_BUS_VOLTAGE = 0x02
|
||||
REG_CALIBRATION = 0x05
|
||||
CALIBRATION_VALUE = 0x1400
|
||||
|
||||
# M01 激光模块指令
|
||||
MODULE_ADDR = 0x00
|
||||
LASER_ON_CMD = bytes([0xAA, MODULE_ADDR, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x01, 0xC1])
|
||||
LASER_OFF_CMD = bytes([0xAA, MODULE_ADDR, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x00, 0xC0])
|
||||
DISTANCE_QUERY_CMD = bytes([0xAA, MODULE_ADDR, 0x00, 0x20, 0x00, 0x01, 0x00, 0x00, 0x21])
|
||||
DISTANCE_RESPONSE_LEN = 13
|
||||
|
||||
# TCP / 线程状态
|
||||
tcp_connected = False
|
||||
send_queue = []
|
||||
update_thread_started = False # 防止重复 OTA
|
||||
send_queue_lock = _thread.allocate_lock()
|
||||
laser_calibration_data_lock = _thread.allocate_lock()
|
||||
laser_calibration_active = False
|
||||
laser_calibration_result = None
|
||||
|
||||
|
||||
# ==============================
|
||||
# 网络工具函数
|
||||
# ==============================
|
||||
|
||||
def is_server_reachable(host, port=80, timeout=5):
|
||||
"""检查能否连接到指定主机和端口(用于 OTA 前网络检测)"""
|
||||
try:
|
||||
addr_info = socket.getaddrinfo(host, port)[0]
|
||||
s = socket.socket(addr_info[0], addr_info[1], addr_info[2])
|
||||
s.settimeout(timeout)
|
||||
s.connect(addr_info[-1])
|
||||
s.close()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"[NET] 无法连接 {host}:{port} - {e}")
|
||||
return False
|
||||
|
||||
|
||||
def download_file(url, filename):
|
||||
"""
|
||||
从指定 URL 下载文件并保存为 UTF-8 文本。
|
||||
注意:此操作会覆盖本地 main.py!
|
||||
"""
|
||||
try:
|
||||
print(f"[OTA] 正在从 {url} 下载文件...")
|
||||
response = requests.get(url, timeout=10) # ⏱️ 防止卡死
|
||||
response.raise_for_status()
|
||||
response.encoding = 'utf-8'
|
||||
with open(filename, 'w', encoding='utf-8') as file:
|
||||
file.write(response.text)
|
||||
return f"下载成功!文件已保存为: {filename}"
|
||||
except requests.exceptions.RequestException as e:
|
||||
return f"下载失败!网络请求错误: {e}"
|
||||
except OSError as e:
|
||||
return f"下载失败!文件写入错误: {e}"
|
||||
except Exception as e:
|
||||
return f"下载失败!发生未知错误: {e}"
|
||||
|
||||
|
||||
def connect_wifi(ssid, password):
|
||||
"""
|
||||
连接 Wi-Fi 并持久化凭证到 /boot/ 目录,使设备重启后自动连接。
|
||||
返回 (ip, error) 元组。
|
||||
"""
|
||||
conf_path = "/etc/wpa_supplicant.conf"
|
||||
ssid_file = "/boot/wifi.ssid"
|
||||
pass_file = "/boot/wifi.pass"
|
||||
|
||||
try:
|
||||
# 生成 wpa_supplicant 配置
|
||||
net_conf = os.popen(f'wpa_passphrase "{ssid}" "{password}"').read()
|
||||
if "network={" not in net_conf:
|
||||
return None, "Failed to generate wpa config"
|
||||
|
||||
# 写入运行时配置
|
||||
with open(conf_path, "w") as f:
|
||||
f.write("ctrl_interface=/var/run/wpa_supplicant\n")
|
||||
f.write("update_config=1\n\n")
|
||||
f.write(net_conf)
|
||||
|
||||
# 持久化保存(供开机脚本读取)
|
||||
with open(ssid_file, "w") as f:
|
||||
f.write(ssid.strip())
|
||||
with open(pass_file, "w") as f:
|
||||
f.write(password.strip())
|
||||
|
||||
# 重启 Wi-Fi 服务
|
||||
os.system("/etc/init.d/S30wifi restart")
|
||||
|
||||
# 等待获取 IP(最多 20 秒)
|
||||
for _ in range(20):
|
||||
ip = os.popen("ifconfig wlan0 2>/dev/null | grep 'inet ' | awk '{print $2}'").read().strip()
|
||||
if ip:
|
||||
return ip, None
|
||||
time.sleep(1)
|
||||
|
||||
return None, "Timeout: No IP obtained"
|
||||
|
||||
except Exception as e:
|
||||
return None, f"Exception: {str(e)}"
|
||||
|
||||
def direct_ota_download():
|
||||
"""
|
||||
直接执行 OTA 下载(假设已有网络)
|
||||
用于 cmd=7 触发
|
||||
"""
|
||||
global update_thread_started
|
||||
try:
|
||||
# 再次确认网络可达(可选但推荐)
|
||||
from urllib.parse import urlparse
|
||||
parsed_url = urlparse(url)
|
||||
host = parsed_url.hostname
|
||||
port = parsed_url.port or (443 if parsed_url.scheme == 'https' else 80)
|
||||
|
||||
if not is_server_reachable(host, port, timeout=8):
|
||||
safe_enqueue({"result": "ota_failed", "reason": f"无法连接 {host}:{port}"}, MSG_TYPE_STATUS)
|
||||
return
|
||||
|
||||
print(f"[OTA] 开始直接下载固件...")
|
||||
result_msg = download_file(url, local_filename)
|
||||
print(f"[OTA] {result_msg}")
|
||||
safe_enqueue({"result": result_msg}, MSG_TYPE_STATUS)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"OTA 异常: {str(e)}"
|
||||
print(error_msg)
|
||||
safe_enqueue({"result": "ota_failed", "reason": error_msg}, MSG_TYPE_STATUS)
|
||||
finally:
|
||||
update_thread_started = False # 允许下次 OTA
|
||||
|
||||
|
||||
def handle_wifi_and_update(ssid, password):
|
||||
"""
|
||||
OTA 更新线程入口。
|
||||
注意:必须在 finally 中重置 update_thread_started!
|
||||
"""
|
||||
global update_thread_started
|
||||
try:
|
||||
ip, error = connect_wifi(ssid, password)
|
||||
if error:
|
||||
safe_enqueue({"result": "wifi_failed", "error": error}, MSG_TYPE_STATUS)
|
||||
return
|
||||
|
||||
safe_enqueue({"result": "wifi_connected", "ip": ip}, MSG_TYPE_STATUS)
|
||||
|
||||
from urllib.parse import urlparse
|
||||
parsed_url = urlparse(url)
|
||||
host = parsed_url.hostname
|
||||
port = parsed_url.port or (443 if parsed_url.scheme == 'https' else 80)
|
||||
|
||||
if not is_server_reachable(host, port, timeout=8):
|
||||
err_msg = f"网络不通:无法连接 {host}:{port}"
|
||||
safe_enqueue({"result": err_msg}, MSG_TYPE_STATUS)
|
||||
return
|
||||
|
||||
print(f"[OTA] 已确认可访问 {host}:{port},开始下载...")
|
||||
try:
|
||||
cs = download_file(url, local_filename)
|
||||
except Exception as e:
|
||||
cs = f"下载失败: {str(e)}"
|
||||
print(cs)
|
||||
safe_enqueue({"result": cs}, MSG_TYPE_STATUS)
|
||||
|
||||
finally:
|
||||
# ✅ 关键修复:允许下次 OTA
|
||||
update_thread_started = False
|
||||
print("[UPDATE] OTA 线程执行完毕,标志已重置。")
|
||||
|
||||
|
||||
# ==============================
|
||||
# 工具函数
|
||||
# ==============================
|
||||
|
||||
def read_device_id():
|
||||
"""从 /device_key 读取设备唯一 ID"""
|
||||
try:
|
||||
with open("/device_key", "r") as f:
|
||||
device_id = f.read().strip()
|
||||
if device_id:
|
||||
print(f"[INFO] 从 /device_key 读取到 DEVICE_ID: {device_id}")
|
||||
return device_id
|
||||
else:
|
||||
raise ValueError("文件为空")
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 无法读取 /device_key: {e}")
|
||||
return "DEFAULT_DEVICE_ID"
|
||||
|
||||
|
||||
def safe_enqueue(data_dict, msg_type=MSG_TYPE_STATUS):
|
||||
"""线程安全地将消息加入发送队列"""
|
||||
global send_queue, send_queue_lock
|
||||
with send_queue_lock:
|
||||
send_queue.append((msg_type, data_dict))
|
||||
|
||||
|
||||
def at(cmd, wait="OK", timeout=2000):
|
||||
"""向 4G 模块发送 AT 指令并等待响应"""
|
||||
if cmd:
|
||||
uart4g.write((cmd + "\r\n").encode())
|
||||
t0 = time.ticks_ms()
|
||||
buf = b""
|
||||
while time.ticks_ms() - t0 < timeout:
|
||||
data = uart4g.read()
|
||||
if data:
|
||||
buf += data
|
||||
if wait.encode() in buf:
|
||||
return buf.decode(errors="ignore")
|
||||
return buf.decode(errors="ignore")
|
||||
|
||||
|
||||
def make_packet(msg_type: int, body_dict: dict) -> bytes:
|
||||
"""构造二进制数据包:[body_len][msg_type][checksum][body]"""
|
||||
body = json.dumps(body_dict, ensure_ascii=False).encode('utf-8')
|
||||
body_len = len(body)
|
||||
checksum = body_len + msg_type
|
||||
header = struct.pack(">III", body_len, msg_type, checksum)
|
||||
return header + body
|
||||
|
||||
|
||||
def parse_packet(data: bytes):
|
||||
"""解析二进制数据包"""
|
||||
if len(data) < 12:
|
||||
return None, None
|
||||
body_len, msg_type, checksum = struct.unpack(">III", data[:12])
|
||||
body = data[12:12 + body_len]
|
||||
try:
|
||||
# ✅ 显式指定 UTF-8 编码
|
||||
return msg_type, json.loads(body.decode('utf-8'))
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 解析包体失败: {e}")
|
||||
return msg_type, {"raw": body.decode('utf-8', errors='ignore')}
|
||||
|
||||
|
||||
def tcp_send_raw(data: bytes, max_retries=2) -> bool:
|
||||
"""通过 4G 模块发送原始 TCP 数据(仅在 tcp_main 线程调用)"""
|
||||
global tcp_connected
|
||||
if not tcp_connected:
|
||||
return False
|
||||
|
||||
for attempt in range(max_retries):
|
||||
cmd = f'AT+MIPSEND=0,{len(data)}'
|
||||
if ">" not in at(cmd, ">", 1500):
|
||||
time.sleep_ms(100)
|
||||
continue
|
||||
|
||||
time.sleep_ms(10)
|
||||
full = data + b"\x1A"
|
||||
try:
|
||||
sent = uart4g.write(full)
|
||||
if sent != len(full):
|
||||
time.sleep_ms(100)
|
||||
continue
|
||||
except:
|
||||
time.sleep_ms(100)
|
||||
continue
|
||||
|
||||
if "OK" in at("", "OK", 1000):
|
||||
return True
|
||||
time.sleep_ms(100)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def load_laser_point():
|
||||
"""从配置文件加载激光点坐标"""
|
||||
global laser_point
|
||||
try:
|
||||
if "laser_config.json" in os.listdir("/root"):
|
||||
with open(CONFIG_FILE, "r") as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, list) and len(data) == 2:
|
||||
laser_point = (int(data[0]), int(data[1]))
|
||||
print(f"[INFO] 加载激光点: {laser_point}")
|
||||
else:
|
||||
raise ValueError
|
||||
else:
|
||||
laser_point = DEFAULT_POINT
|
||||
except:
|
||||
laser_point = DEFAULT_POINT
|
||||
|
||||
|
||||
def save_laser_point(point):
|
||||
"""保存激光点坐标到文件"""
|
||||
global laser_point
|
||||
try:
|
||||
with open(CONFIG_FILE, "w") as f:
|
||||
json.dump([point[0], point[1]], f)
|
||||
laser_point = point
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def turn_on_laser():
|
||||
"""发送激光开启指令"""
|
||||
distance_serial.write(LASER_ON_CMD)
|
||||
time.sleep_ms(10)
|
||||
resp = distance_serial.read(20)
|
||||
if resp:
|
||||
if resp == LASER_ON_CMD:
|
||||
print("✅ 激光指令已确认")
|
||||
else:
|
||||
print("🔇 无回包(正常或模块不支持)")
|
||||
return resp
|
||||
|
||||
|
||||
# ==============================
|
||||
# M01 激光测距模块
|
||||
# ==============================
|
||||
|
||||
def parse_bcd_distance(bcd_bytes: bytes) -> float:
|
||||
"""将 4 字节 BCD 码转换为距离(米)"""
|
||||
if len(bcd_bytes) != 4:
|
||||
return 0.0
|
||||
try:
|
||||
hex_string = binascii.hexlify(bcd_bytes).decode()
|
||||
distance_int = int(hex_string)
|
||||
return distance_int / 1000.0
|
||||
except Exception as e:
|
||||
print(f"[ERROR] BCD 解析失败: {e}")
|
||||
return 0.0
|
||||
|
||||
|
||||
def read_distance_from_laser_sensor():
|
||||
"""发送测距指令并返回距离(米)"""
|
||||
global distance_serial
|
||||
try:
|
||||
distance_serial.read() # 清空缓冲区
|
||||
distance_serial.write(DISTANCE_QUERY_CMD)
|
||||
time.sleep_ms(500)
|
||||
response = distance_serial.read(DISTANCE_RESPONSE_LEN)
|
||||
|
||||
if response and len(response) == DISTANCE_RESPONSE_LEN:
|
||||
if response[3] != 0x20:
|
||||
if response[0] == 0xEE:
|
||||
err_code = (response[7] << 8) | response[8]
|
||||
print(f"[LASER] 模块错误代码: {hex(err_code)}")
|
||||
return 0.0
|
||||
|
||||
bcd_bytes = response[6:10]
|
||||
distance_value_m = parse_bcd_distance(bcd_bytes)
|
||||
signal_quality = (response[10] << 8) | response[11]
|
||||
print(f"[LASER] 测距成功: {distance_value_m:.3f} m, 信号质量: {signal_quality}")
|
||||
return distance_value_m
|
||||
|
||||
print(f"[LASER] 无效响应: {response.hex() if response else 'None'}")
|
||||
return 0.0
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 读取激光测距失败: {e}")
|
||||
return 0.0
|
||||
|
||||
|
||||
# ==============================
|
||||
# 激光点校准
|
||||
# ==============================
|
||||
|
||||
def find_red_laser(frame, threshold=150):
|
||||
"""在图像中查找最亮的红色点(简单 RGB 判定)"""
|
||||
w, h = frame.width(), frame.height()
|
||||
img_bytes = frame.to_bytes()
|
||||
max_sum = 0
|
||||
best_pos = None
|
||||
for y in range(0, h, 2):
|
||||
for x in range(0, w, 2):
|
||||
idx = (y * w + x) * 3
|
||||
r, g, b = img_bytes[idx], img_bytes[idx+1], img_bytes[idx+2]
|
||||
if r > threshold and r > g * 2 and r > b * 2:
|
||||
rgb_sum = r + g + b
|
||||
if rgb_sum > max_sum:
|
||||
max_sum = rgb_sum
|
||||
best_pos = (x, y)
|
||||
return best_pos
|
||||
|
||||
|
||||
def calibrate_laser_position():
|
||||
"""拍摄一帧并识别激光点位置"""
|
||||
time.sleep_ms(80)
|
||||
cam = camera.Camera(640, 480)
|
||||
frame = cam.read()
|
||||
pos = find_red_laser(frame)
|
||||
if pos:
|
||||
save_laser_point(pos)
|
||||
return pos
|
||||
return None
|
||||
|
||||
|
||||
# ==============================
|
||||
# 电量监测(INA226)
|
||||
# ==============================
|
||||
|
||||
def write_register(reg, value):
|
||||
data = [(value >> 8) & 0xFF, value & 0xFF]
|
||||
bus.writeto_mem(INA226_ADDR, reg, bytes(data))
|
||||
|
||||
|
||||
def read_register(reg):
|
||||
data = bus.readfrom_mem(INA226_ADDR, reg, 2)
|
||||
return (data[0] << 8) | data[1]
|
||||
|
||||
|
||||
def init_ina226():
|
||||
write_register(REG_CONFIGURATION, 0x4527)
|
||||
write_register(REG_CALIBRATION, CALIBRATION_VALUE)
|
||||
|
||||
|
||||
def get_bus_voltage():
|
||||
raw = read_register(REG_BUS_VOLTAGE)
|
||||
return raw * 1.25 / 1000
|
||||
|
||||
|
||||
def voltage_to_percent(voltage):
|
||||
points = [
|
||||
(4.20, 100), (4.10, 95), (4.05, 85), (4.00, 75), (3.95, 65),
|
||||
(3.90, 55), (3.85, 45), (3.80, 35), (3.75, 25), (3.70, 15),
|
||||
(3.65, 5), (3.60, 0)
|
||||
]
|
||||
if voltage >= points[0][0]: return 100
|
||||
if voltage <= points[-1][0]: return 0
|
||||
for i in range(len(points) - 1):
|
||||
v1, p1 = points[i]; v2, p2 = points[i + 1]
|
||||
if v2 <= voltage <= v1:
|
||||
ratio = (voltage - v1) / (v2 - v1)
|
||||
percent = p1 + (p2 - p1) * ratio
|
||||
return max(0, min(100, int(round(percent))))
|
||||
return 0
|
||||
|
||||
|
||||
# ==============================
|
||||
# 目标检测
|
||||
# ==============================
|
||||
|
||||
def detect_circle(frame):
|
||||
"""检测靶心圆(清晰/模糊两种模式)"""
|
||||
img_cv = image.image2cv(frame, False, False)
|
||||
gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
|
||||
blurred = cv2.GaussianBlur(gray, (5, 5), 0)
|
||||
edged = cv2.Canny(blurred, 50, 150)
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
||||
ceroded = cv2.erode(cv2.dilate(edged, kernel), kernel)
|
||||
|
||||
contours, _ = cv2.findContours(ceroded, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||
best_center = best_radius = method = None
|
||||
|
||||
for cnt in contours:
|
||||
area = cv2.contourArea(cnt)
|
||||
perimeter = cv2.arcLength(cnt, True)
|
||||
if perimeter < 100 or area < 100: continue
|
||||
circularity = 4 * np.pi * area / (perimeter ** 2)
|
||||
if circularity > 0.75 and len(cnt) >= 5:
|
||||
center, axes, angle = cv2.fitEllipse(cnt)
|
||||
radius = (axes[0] + axes[1]) / 4
|
||||
best_center = (int(center[0]), int(center[1]))
|
||||
best_radius = int(radius)
|
||||
method = "清晰"
|
||||
break
|
||||
|
||||
if not best_center:
|
||||
hsv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2HSV)
|
||||
h, s, v = cv2.split(hsv)
|
||||
s = np.clip(s * 2, 0, 255).astype(np.uint8)
|
||||
hsv = cv2.merge((h, s, v))
|
||||
lower_yellow = np.array([7, 80, 0])
|
||||
upper_yellow = np.array([32, 255, 182])
|
||||
mask = cv2.inRange(hsv, lower_yellow, upper_yellow)
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
||||
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
|
||||
mask = cv2.morphologyEx(mask, cv2.MORPH_DILATE, kernel)
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
if contours:
|
||||
largest = max(contours, key=cv2.contourArea)
|
||||
if cv2.contourArea(largest) > 50:
|
||||
(x, y), radius = cv2.minEnclosingCircle(largest)
|
||||
best_center = (int(x), int(y))
|
||||
best_radius = int(radius)
|
||||
method = "模糊"
|
||||
|
||||
result_img = image.cv2image(img_cv, False, False)
|
||||
return result_img, best_center, best_radius, method, best_radius
|
||||
|
||||
|
||||
def compute_laser_position(circle_center, laser_point, radius, method):
|
||||
"""计算激光相对于靶心的偏差(单位:厘米)"""
|
||||
if not all([circle_center, radius, method]):
|
||||
return None, None
|
||||
cx, cy = circle_center
|
||||
lx, ly = laser_point
|
||||
# 根据检测模式估算实际半径(单位:像素 → 厘米)
|
||||
circle_r_cm = (radius / 4.0) * 20.0 if method == "模糊" else (68 / 16.0) * 20.0
|
||||
dx = lx - cx
|
||||
dy = ly - cy
|
||||
scale = circle_r_cm / radius if radius != 0 else 1.0
|
||||
return dx * scale, -dy * scale
|
||||
|
||||
|
||||
# ==============================
|
||||
# TCP 通信主线程
|
||||
# ==============================
|
||||
|
||||
def connect_server():
|
||||
"""连接服务器(通过 4G 模块 AT 指令)"""
|
||||
global tcp_connected
|
||||
if tcp_connected:
|
||||
return True
|
||||
print("正在连接服务器...")
|
||||
at("AT+MIPCLOSE=0", "OK", 1000)
|
||||
res = at(f'AT+MIPOPEN=0,"TCP","{SERVER_IP}",{SERVER_PORT}', "+MIPOPEN", 8000)
|
||||
if "+MIPOPEN: 0,0" in res:
|
||||
tcp_connected = True
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def tcp_main():
|
||||
"""TCP 通信主循环(独立线程)"""
|
||||
global tcp_connected, send_queue, laser_calibration_active, laser_calibration_result,update_thread_started
|
||||
|
||||
while not app.need_exit():
|
||||
if not connect_server():
|
||||
time.sleep_ms(5000)
|
||||
continue
|
||||
|
||||
login_data = {"deviceId": DEVICE_ID, "password": PASSWORD}
|
||||
if not tcp_send_raw(make_packet(MSG_TYPE_LOGIN_REQ, login_data)):
|
||||
tcp_connected = False
|
||||
time.sleep_ms(2000)
|
||||
continue
|
||||
|
||||
print("➡️ 登录包已发送,等待确认...")
|
||||
logged_in = False
|
||||
last_heartbeat_ack_time = time.ticks_ms()
|
||||
last_heartbeat_send_time = time.ticks_ms()
|
||||
rx_buf = b""
|
||||
|
||||
while True:
|
||||
data = uart4g.read()
|
||||
if data:
|
||||
rx_buf += data
|
||||
while b'+MIPURC: "rtcp"' in rx_buf:
|
||||
try:
|
||||
match = re.search(b'\+MIPURC: "rtcp",0,(\d+),(.+)', rx_buf, re.DOTALL)
|
||||
if match:
|
||||
payload_len = int(match.group(1))
|
||||
payload = match.group(2)[:payload_len]
|
||||
msg_type, body = parse_packet(payload)
|
||||
|
||||
if not logged_in and msg_type == MSG_TYPE_LOGIN_REQ:
|
||||
if body and body.get("cmd") == 1 and body.get("data") == "登录成功":
|
||||
logged_in = True
|
||||
last_heartbeat_ack_time = time.ticks_ms()
|
||||
print("✅ 登录成功")
|
||||
else:
|
||||
break
|
||||
|
||||
elif logged_in and msg_type == MSG_TYPE_HEARTBEAT:
|
||||
last_heartbeat_ack_time = time.ticks_ms()
|
||||
print("✅ 收到心跳确认")
|
||||
|
||||
elif logged_in and isinstance(body, dict):
|
||||
inner_data = body.get("data", {})
|
||||
if isinstance(inner_data, dict) and "cmd" in inner_data:
|
||||
inner_cmd = inner_data["cmd"]
|
||||
if inner_cmd == 2:
|
||||
turn_on_laser()
|
||||
time.sleep_ms(100)
|
||||
laser_calibration_active = True
|
||||
safe_enqueue({"result": "calibrating"}, MSG_TYPE_STATUS)
|
||||
elif inner_cmd == 3:
|
||||
distance_serial.write(LASER_OFF_CMD)
|
||||
laser_calibration_active = False
|
||||
safe_enqueue({"result": "laser_off"}, MSG_TYPE_STATUS)
|
||||
elif inner_cmd == 4:
|
||||
voltage = get_bus_voltage()
|
||||
battery_percent = voltage_to_percent(voltage)
|
||||
battery_data = {"battery": battery_percent, "voltage": round(voltage, 3)}
|
||||
safe_enqueue(battery_data, MSG_TYPE_STATUS)
|
||||
elif inner_cmd == 5:
|
||||
ssid = inner_data.get("ssid")
|
||||
password = inner_data.get("password")
|
||||
if not ssid or not password:
|
||||
safe_enqueue({"result": "missing_ssid_or_password"}, MSG_TYPE_STATUS)
|
||||
else:
|
||||
# global update_thread_started
|
||||
if not update_thread_started:
|
||||
update_thread_started = True
|
||||
_thread.start_new_thread(handle_wifi_and_update, (ssid, password))
|
||||
else:
|
||||
safe_enqueue({"result": "update_already_started"}, MSG_TYPE_STATUS)
|
||||
elif inner_cmd == 6:
|
||||
try:
|
||||
ip = os.popen("ifconfig wlan0 2>/dev/null | grep 'inet ' | awk '{print $2}'").read().strip()
|
||||
ip = ip if ip else "no_ip"
|
||||
except:
|
||||
ip = "error_getting_ip"
|
||||
safe_enqueue({"result": "current_ip", "ip": ip}, MSG_TYPE_STATUS)
|
||||
|
||||
elif inner_cmd == 7:
|
||||
# global update_thread_started
|
||||
if update_thread_started:
|
||||
safe_enqueue({"result": "update_already_started"}, MSG_TYPE_STATUS)
|
||||
continue
|
||||
|
||||
# 实时检查是否有 IP
|
||||
try:
|
||||
ip = os.popen("ifconfig wlan0 2>/dev/null | grep 'inet ' | awk '{print $2}'").read().strip()
|
||||
except:
|
||||
ip = None
|
||||
|
||||
if not ip:
|
||||
safe_enqueue({"result": "ota_rejected", "reason": "no_wifi_ip"}, MSG_TYPE_STATUS)
|
||||
else:
|
||||
# 启动纯下载线程
|
||||
update_thread_started = True
|
||||
_thread.start_new_thread(direct_ota_download, ())
|
||||
rx_buf = rx_buf[match.end():]
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 解析/处理数据包失败: {e}")
|
||||
rx_buf = b""
|
||||
break
|
||||
|
||||
# 发送队列处理
|
||||
msg_type = None
|
||||
if logged_in:
|
||||
with send_queue_lock:
|
||||
if send_queue:
|
||||
msg_type, data_dict = send_queue.pop(0)
|
||||
if msg_type is not None:
|
||||
pkt = make_packet(msg_type, data_dict)
|
||||
if not tcp_send_raw(pkt):
|
||||
print("💔 发送失败,断开重连")
|
||||
break
|
||||
|
||||
# 校准结果上报
|
||||
if logged_in:
|
||||
x = y = None
|
||||
with laser_calibration_data_lock:
|
||||
if laser_calibration_result is not None:
|
||||
x, y = laser_calibration_result
|
||||
laser_calibration_result = None
|
||||
if x is not None:
|
||||
safe_enqueue({"result": "ok", "x": x, "y": y}, MSG_TYPE_STATUS)
|
||||
|
||||
# 心跳机制
|
||||
current_time = time.ticks_ms()
|
||||
if logged_in and current_time - last_heartbeat_send_time > HEARTBEAT_INTERVAL * 1000:
|
||||
if not tcp_send_raw(make_packet(MSG_TYPE_HEARTBEAT, {"t": int(time.time())})):
|
||||
print("💔 心跳发送失败")
|
||||
break
|
||||
last_heartbeat_send_time = current_time
|
||||
|
||||
if logged_in and current_time - last_heartbeat_ack_time > 6000:
|
||||
print("⏰ 6秒无心跳ACK,重连")
|
||||
break
|
||||
|
||||
time.sleep_ms(50)
|
||||
|
||||
tcp_connected = False
|
||||
time.sleep_ms(2000)
|
||||
|
||||
|
||||
def laser_calibration_worker():
|
||||
"""后台激光校准线程"""
|
||||
global laser_calibration_active, laser_calibration_result
|
||||
while True:
|
||||
if laser_calibration_active:
|
||||
result = calibrate_laser_position()
|
||||
if result and len(result) == 2:
|
||||
with laser_calibration_data_lock:
|
||||
laser_calibration_result = result
|
||||
laser_calibration_active = False
|
||||
print(f"✅ 后台校准成功: {result}")
|
||||
else:
|
||||
time.sleep_ms(80)
|
||||
else:
|
||||
time.sleep_ms(50)
|
||||
|
||||
|
||||
# ==============================
|
||||
# 主程序入口
|
||||
# ==============================
|
||||
|
||||
def cmd_str():
|
||||
global DEVICE_ID, PASSWORD
|
||||
DEVICE_ID = read_device_id()
|
||||
PASSWORD = DEVICE_ID + "."
|
||||
|
||||
photo_dir = "/root/phot"
|
||||
if photo_dir not in os.listdir("/root"):
|
||||
try:
|
||||
os.mkdir(photo_dir)
|
||||
except:
|
||||
pass
|
||||
|
||||
init_ina226()
|
||||
load_laser_point()
|
||||
|
||||
disp = display.Display()
|
||||
cam = camera.Camera(640, 480)
|
||||
|
||||
_thread.start_new_thread(tcp_main, ())
|
||||
_thread.start_new_thread(laser_calibration_worker, ())
|
||||
|
||||
print("系统准备完成...")
|
||||
|
||||
while not app.need_exit():
|
||||
if adc_obj.read() > ADC_TRIGGER_THRESHOLD:
|
||||
time.sleep_ms(60)
|
||||
frame = cam.read()
|
||||
|
||||
x, y = laser_point
|
||||
frame.draw_line(int(x - length), int(y), int(x + length), int(y), color, thickness)
|
||||
frame.draw_line(int(x), int(y - length), int(x), int(y + length), color, thickness)
|
||||
frame.draw_circle(int(x), int(y), 1, color, thickness)
|
||||
|
||||
result_img, center, radius, method, _ = detect_circle(frame)
|
||||
disp.show(result_img)
|
||||
|
||||
dx, dy = compute_laser_position(center, (x, y), radius, method)
|
||||
distance_m = read_distance_from_laser_sensor()
|
||||
voltage = get_bus_voltage()
|
||||
battery_percent = voltage_to_percent(voltage)
|
||||
|
||||
try:
|
||||
jpg_count = len([f for f in os.listdir(photo_dir) if f.endswith('.jpg')])
|
||||
filename = f"{photo_dir}/{int(x)}_{int(y)}_{round((distance_m or 0.0) * 100)}_{method}_{jpg_count:04d}.jpg"
|
||||
result_img.save(filename, quality=70)
|
||||
except Exception as e:
|
||||
print(f"❌ 保存照片失败: {e}")
|
||||
|
||||
inner_data = {
|
||||
"x": float(dx) if dx is not None else 200.0,
|
||||
"y": float(dy) if dy is not None else 200.0,
|
||||
"r": 90.0,
|
||||
"d": round((distance_m or 0.0) * 100),
|
||||
"m": method
|
||||
}
|
||||
report_data = {"cmd": 1, "data": inner_data}
|
||||
safe_enqueue(report_data, MSG_TYPE_STATUS)
|
||||
|
||||
time.sleep_ms(100)
|
||||
else:
|
||||
disp.show(cam.read())
|
||||
time.sleep_ms(50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cmd_str()
|
||||
1272
laser_manager.py
Normal file
1272
laser_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
212
logger_manager.py
Normal file
212
logger_manager.py
Normal file
@@ -0,0 +1,212 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
日志管理器模块
|
||||
提供异步日志功能(使用 QueueHandler + QueueListener)
|
||||
"""
|
||||
import logging
|
||||
from logging.handlers import QueueHandler, QueueListener, RotatingFileHandler
|
||||
import queue
|
||||
import os
|
||||
import config
|
||||
from version import VERSION
|
||||
|
||||
|
||||
class LoggerManager:
|
||||
"""日志管理器(单例)"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(LoggerManager, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 私有状态
|
||||
self._log_queue = None
|
||||
self._queue_listener = None
|
||||
self._logger = None
|
||||
|
||||
self._initialized = True
|
||||
|
||||
# ==================== 状态访问(只读属性)====================
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
"""获取logger对象(只读)"""
|
||||
return self._logger
|
||||
|
||||
@property
|
||||
def log_queue(self):
|
||||
"""获取日志队列(只读)"""
|
||||
return self._log_queue
|
||||
|
||||
# ==================== 业务方法 ====================
|
||||
|
||||
def init_logging(self, log_level=logging.INFO, log_file=None, max_bytes=None, backup_count=None):
|
||||
"""
|
||||
初始化异步日志系统(使用 QueueHandler + QueueListener)
|
||||
|
||||
Args:
|
||||
log_level: 日志级别,默认 INFO
|
||||
log_file: 日志文件路径,默认使用 config.LOG_FILE
|
||||
max_bytes: 单个日志文件最大大小(字节),默认使用 config.LOG_MAX_BYTES
|
||||
backup_count: 保留的备份文件数量,默认使用 config.LOG_BACKUP_COUNT
|
||||
"""
|
||||
if log_file is None:
|
||||
log_file = config.LOG_FILE
|
||||
if max_bytes is None:
|
||||
max_bytes = config.LOG_MAX_BYTES
|
||||
if backup_count is None:
|
||||
backup_count = config.LOG_BACKUP_COUNT
|
||||
|
||||
try:
|
||||
# 创建日志队列(无界队列)
|
||||
self._log_queue = queue.Queue(-1)
|
||||
|
||||
# 确保日志文件所在的目录存在
|
||||
log_dir = os.path.dirname(log_file)
|
||||
if log_dir: # 如果日志路径包含目录
|
||||
try:
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
except Exception as e:
|
||||
print(f"[WARN] 无法创建日志目录 {log_dir}: {e}")
|
||||
|
||||
# 尝试创建文件Handler(带日志轮转)
|
||||
try:
|
||||
file_handler = RotatingFileHandler(
|
||||
log_file,
|
||||
maxBytes=max_bytes,
|
||||
backupCount=backup_count,
|
||||
encoding='utf-8',
|
||||
mode='a' # 追加模式,确保不覆盖
|
||||
)
|
||||
except Exception as e:
|
||||
# 如果RotatingFileHandler不可用,降级为普通FileHandler
|
||||
print(f"[WARN] RotatingFileHandler不可用,使用普通FileHandler: {e}")
|
||||
try:
|
||||
file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='a')
|
||||
except Exception as e2:
|
||||
# 如果文件Handler创建失败,只使用控制台Handler
|
||||
print(f"[WARN] 无法创建文件Handler,仅使用控制台输出: {e2}")
|
||||
file_handler = None
|
||||
|
||||
# 自定义Formatter,包含版本信息
|
||||
class CustomFormatter(logging.Formatter):
|
||||
"""自定义日志格式,包含版本信息和行号"""
|
||||
def format(self, record):
|
||||
record.version = VERSION
|
||||
return super().format(record)
|
||||
|
||||
# 如果file_handler存在,设置格式和级别
|
||||
if file_handler is not None:
|
||||
file_handler.setFormatter(CustomFormatter(
|
||||
'%(asctime)s [v%(version)s] [%(levelname)s] %(filename)s:%(lineno)d - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
))
|
||||
file_handler.setLevel(log_level)
|
||||
|
||||
# 创建控制台Handler(保留原有的print输出)
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(CustomFormatter(
|
||||
'[v%(version)s] [%(levelname)s] %(filename)s:%(lineno)d - %(message)s'
|
||||
))
|
||||
console_handler.setLevel(log_level)
|
||||
|
||||
# 创建QueueListener(后台线程处理日志写入)
|
||||
# 如果file_handler为None,只使用console_handler
|
||||
handlers = [console_handler]
|
||||
if file_handler is not None:
|
||||
handlers.append(file_handler)
|
||||
|
||||
self._queue_listener = QueueListener(
|
||||
self._log_queue,
|
||||
*handlers,
|
||||
respect_handler_level=True
|
||||
)
|
||||
self._queue_listener.start()
|
||||
|
||||
# 创建QueueHandler(用于记录日志)
|
||||
queue_handler = QueueHandler(self._log_queue)
|
||||
|
||||
# 配置根logger
|
||||
self._logger = logging.getLogger()
|
||||
self._logger.addHandler(queue_handler)
|
||||
self._logger.setLevel(log_level)
|
||||
|
||||
# 避免日志向上传播到其他logger
|
||||
self._logger.propagate = False
|
||||
|
||||
# 添加启动标记
|
||||
self._logger.info("=" * 60)
|
||||
self._logger.info("程序启动 - 日志系统初始化")
|
||||
self._logger.info(f"版本: {VERSION}")
|
||||
self._logger.info(f"日志文件: {log_file}")
|
||||
self._logger.info("=" * 60)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
# 如果日志初始化失败,至少保证程序能运行
|
||||
print(f"[ERROR] 日志系统初始化失败: {e}")
|
||||
import traceback
|
||||
try:
|
||||
traceback.print_exc()
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
def stop_logging(self):
|
||||
"""停止日志系统(程序退出时调用)"""
|
||||
try:
|
||||
if self._logger:
|
||||
# 确保所有日志都写入
|
||||
self._logger.info("程序退出,正在保存日志...")
|
||||
import time as std_time
|
||||
std_time.sleep(0.5) # 给一点时间让日志写入
|
||||
|
||||
if self._queue_listener:
|
||||
self._queue_listener.stop()
|
||||
|
||||
if self._logger:
|
||||
# 等待队列中的日志处理完成
|
||||
if self._log_queue:
|
||||
import time as std_time
|
||||
timeout = 5
|
||||
start = std_time.time()
|
||||
while not self._log_queue.empty() and (std_time.time() - start) < timeout:
|
||||
std_time.sleep(0.1)
|
||||
print("[LOG] 日志系统已停止")
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 停止日志系统失败: {e}")
|
||||
|
||||
|
||||
# 创建全局单例实例
|
||||
logger_manager = LoggerManager()
|
||||
|
||||
# ==================== 向后兼容的函数接口 ====================
|
||||
|
||||
def init_logging(log_level=logging.INFO, log_file=None, max_bytes=None, backup_count=None):
|
||||
"""初始化日志系统(向后兼容接口)"""
|
||||
return logger_manager.init_logging(log_level, log_file, max_bytes, backup_count)
|
||||
|
||||
def stop_logging():
|
||||
"""停止日志系统(向后兼容接口)"""
|
||||
return logger_manager.stop_logging()
|
||||
|
||||
def get_logger():
|
||||
"""
|
||||
获取全局logger对象(向后兼容接口)
|
||||
如果日志系统未初始化,返回None(此时可以使用print作为fallback)
|
||||
"""
|
||||
return logger_manager.logger
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
BIN
model_270139.cvimodel
Normal file
BIN
model_270139.cvimodel
Normal file
Binary file not shown.
13
model_270139.mud
Normal file
13
model_270139.mud
Normal file
@@ -0,0 +1,13 @@
|
||||
|
||||
[basic]
|
||||
type = cvimodel
|
||||
model = model_270139.cvimodel
|
||||
|
||||
[extra]
|
||||
model_type = yolov5
|
||||
input_type = rgb
|
||||
mean = 0, 0, 0
|
||||
scale = 0.00392156862745098, 0.00392156862745098, 0.00392156862745098
|
||||
anchors = 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326
|
||||
labels = 黑三角和圆环
|
||||
|
||||
BIN
model_270820.cvimodel
Normal file
BIN
model_270820.cvimodel
Normal file
Binary file not shown.
13
model_270820.mud
Normal file
13
model_270820.mud
Normal file
@@ -0,0 +1,13 @@
|
||||
|
||||
[basic]
|
||||
type = cvimodel
|
||||
model = model_270820.cvimodel
|
||||
|
||||
[extra]
|
||||
model_type = yolov5
|
||||
input_type = rgb
|
||||
mean = 0, 0, 0
|
||||
scale = 0.00392156862745098, 0.00392156862745098, 0.00392156862745098
|
||||
anchors = 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326
|
||||
labels = triangle
|
||||
|
||||
2151
network.py
Normal file
2151
network.py
Normal file
File diff suppressed because it is too large
Load Diff
1343
ota_manager.py
Normal file
1343
ota_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
230
package.py
Normal file
230
package.py
Normal file
@@ -0,0 +1,230 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
应用打包脚本
|
||||
根据 app.yaml 中列出的文件,打包成 zip 文件
|
||||
版本号从 version.py 中读取
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import yaml
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
import sys
|
||||
import secrets
|
||||
|
||||
MAGIC = b"AROTAE1" # 7 bytes: Archery OTA Encrypted v1
|
||||
GCM_NONCE_LEN = 12
|
||||
GCM_TAG_LEN = 16
|
||||
|
||||
# 添加当前目录到路径,以便导入 version 模块
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def load_app_yaml(yaml_path='app.yaml'):
|
||||
"""加载 app.yaml 文件"""
|
||||
try:
|
||||
with open(yaml_path, 'r', encoding='utf-8') as f:
|
||||
return yaml.safe_load(f)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 读取 {yaml_path} 失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def check_files_exist(files, base_dir='.'):
|
||||
"""检查文件是否存在"""
|
||||
missing_files = []
|
||||
existing_files = []
|
||||
|
||||
for file_path in files:
|
||||
full_path = os.path.join(base_dir, file_path)
|
||||
if os.path.exists(full_path):
|
||||
existing_files.append(file_path)
|
||||
else:
|
||||
missing_files.append(file_path)
|
||||
|
||||
return existing_files, missing_files
|
||||
|
||||
|
||||
def get_version_from_version_py():
|
||||
"""从 version.py 读取版本号"""
|
||||
try:
|
||||
from version import VERSION
|
||||
return VERSION
|
||||
except ImportError:
|
||||
print("[WARNING] 无法导入 version.py,使用默认版本号 1.0.0")
|
||||
return '1.0.0'
|
||||
except Exception as e:
|
||||
print(f"[WARNING] 读取 version.py 失败: {e},使用默认版本号 1.0.0")
|
||||
return '1.0.0'
|
||||
|
||||
|
||||
def create_zip_package(app_info, files, output_dir='.', base_dir='.'):
|
||||
"""创建 zip 打包文件"""
|
||||
# 生成输出文件名:{name}_v{version}_{timestamp}.zip
|
||||
# 版本号从 version.py 读取,而不是从 app.yaml
|
||||
app_name = app_info.get('name', 'app')
|
||||
version = get_version_from_version_py() # 从 version.py 读取版本号
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
zip_filename = f"{app_name}_v{version}_{timestamp}.zip"
|
||||
zip_path = os.path.join(output_dir, zip_filename)
|
||||
|
||||
print(f"[INFO] 开始打包: {zip_filename}")
|
||||
print(f"[INFO] 包含文件数: {len(files)}")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
||||
for file_path in files:
|
||||
full_path = os.path.join(base_dir, file_path)
|
||||
# 使用相对路径作为 zip 内的路径
|
||||
zipf.write(full_path, file_path)
|
||||
print(f" ✓ {file_path}")
|
||||
|
||||
# 获取文件大小
|
||||
file_size = os.path.getsize(zip_path)
|
||||
file_size_mb = file_size / (1024 * 1024)
|
||||
|
||||
print(f"\n[SUCCESS] 打包完成!")
|
||||
print(f" 文件名: {zip_filename}")
|
||||
print(f" 文件大小: {file_size_mb:.2f} MB ({file_size:,} 字节)")
|
||||
print(f" 文件路径: {os.path.abspath(zip_path)}")
|
||||
|
||||
return zip_path
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 打包失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
def _validate_key_hex(key_hex: str) -> bytes:
|
||||
if not isinstance(key_hex, str):
|
||||
raise ValueError("aead key must be hex string")
|
||||
key_hex = key_hex.strip().lower()
|
||||
if key_hex.startswith("0x"):
|
||||
key_hex = key_hex[2:]
|
||||
if len(key_hex) != 64:
|
||||
raise ValueError("aead key must be 64 hex chars (32 bytes)")
|
||||
try:
|
||||
key = bytes.fromhex(key_hex)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid hex key: {e}")
|
||||
if len(key) != 32:
|
||||
raise ValueError("aead key must be 32 bytes")
|
||||
return key
|
||||
|
||||
|
||||
def encrypt_zip_aead(zip_path: str, key_hex: str, out_ext: str = ".enc") -> str:
|
||||
"""
|
||||
Encrypt the whole zip file as one blob:
|
||||
output format: MAGIC(7) | nonce(12) | ciphertext(N) | tag(16)
|
||||
using AES-256-GCM (AEAD).
|
||||
"""
|
||||
# Lazy import: packaging-only dependency
|
||||
try:
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"Missing dependency: cryptography. Install with: pip install cryptography. "
|
||||
f"Import error: {e}"
|
||||
)
|
||||
|
||||
key = _validate_key_hex(key_hex)
|
||||
with open(zip_path, "rb") as f:
|
||||
plain = f.read()
|
||||
|
||||
nonce = secrets.token_bytes(GCM_NONCE_LEN)
|
||||
aesgcm = AESGCM(key)
|
||||
ct_and_tag = aesgcm.encrypt(nonce, plain, None) # ciphertext || tag (16 bytes)
|
||||
|
||||
enc_path = zip_path + out_ext if out_ext else (zip_path + ".enc")
|
||||
with open(enc_path, "wb") as f:
|
||||
f.write(MAGIC)
|
||||
f.write(nonce)
|
||||
f.write(ct_and_tag)
|
||||
|
||||
return enc_path
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
parser = argparse.ArgumentParser(description="打包 app.yaml 文件列表到 zip,并可选进行 AES-256-GCM 加密输出 .enc")
|
||||
parser.add_argument("--aead-key-hex", default=None, help="AES-256-GCM key (64 hex chars = 32 bytes). If set, output encrypted file.")
|
||||
parser.add_argument("--keep-zip", action="store_true", help="Keep the plaintext zip when encryption is enabled.")
|
||||
parser.add_argument("--out-ext", default=".enc", help="Encrypted output extension appended to zip path. Default: .enc (produces *.zip.enc)")
|
||||
args = parser.parse_args()
|
||||
|
||||
print("=" * 60)
|
||||
print("应用打包脚本")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 加载 app.yaml
|
||||
app_info = load_app_yaml('app.yaml')
|
||||
if app_info is None:
|
||||
return
|
||||
|
||||
# 从 version.py 读取版本号
|
||||
version = get_version_from_version_py()
|
||||
|
||||
print(f"\n[INFO] 应用信息:")
|
||||
print(f" ID: {app_info.get('id', 'N/A')}")
|
||||
print(f" 名称: {app_info.get('name', 'N/A')}")
|
||||
print(f" 版本: {version} (来自 version.py)")
|
||||
print(f" 作者: {app_info.get('author', 'N/A')}")
|
||||
if app_info.get('version') != version:
|
||||
print(f" [注意] app.yaml 中的版本 ({app_info.get('version', 'N/A')}) 与 version.py 不一致")
|
||||
|
||||
# 2. 获取文件列表
|
||||
files = app_info.get('files', [])
|
||||
if not files:
|
||||
print("[ERROR] app.yaml 中没有找到 files 列表")
|
||||
return
|
||||
|
||||
print(f"\n[INFO] 文件列表 ({len(files)} 个文件):")
|
||||
|
||||
# 3. 检查文件是否存在
|
||||
existing_files, missing_files = check_files_exist(files)
|
||||
|
||||
if missing_files:
|
||||
print(f"\n[WARNING] 以下文件不存在,将被跳过:")
|
||||
for f in missing_files:
|
||||
print(f" ✗ {f}")
|
||||
|
||||
if not existing_files:
|
||||
print("\n[ERROR] 没有找到任何有效文件,无法打包")
|
||||
return
|
||||
|
||||
print(f"\n[INFO] 找到 {len(existing_files)} 个有效文件")
|
||||
|
||||
# 4. 创建 zip 包
|
||||
zip_path = create_zip_package(app_info, existing_files)
|
||||
|
||||
if zip_path:
|
||||
enc_path = None
|
||||
if args.aead_key_hex:
|
||||
try:
|
||||
enc_path = encrypt_zip_aead(zip_path, args.aead_key_hex, out_ext=args.out_ext)
|
||||
enc_size = os.path.getsize(enc_path)
|
||||
print(f"\n[SUCCESS] AEAD加密完成: {os.path.basename(enc_path)} ({enc_size:,} bytes)")
|
||||
print(f" 文件路径: {os.path.abspath(enc_path)}")
|
||||
if not args.keep_zip:
|
||||
try:
|
||||
os.remove(zip_path)
|
||||
print(f"[INFO] 已删除明文zip: {os.path.basename(zip_path)}")
|
||||
except Exception as e:
|
||||
print(f"[WARNING] 删除明文zip失败(可忽略): {e}")
|
||||
except Exception as e:
|
||||
print(f"\n[ERROR] AEAD加密失败: {e}")
|
||||
print("[ERROR] 保留明文zip用于排查。")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("打包成功完成!")
|
||||
print("=" * 60)
|
||||
else:
|
||||
print("\n" + "=" * 60)
|
||||
print("打包失败!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
254
power.py
Normal file
254
power.py
Normal file
@@ -0,0 +1,254 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
电源管理模块(INA226)
|
||||
提供电压、电流监测和充电状态检测
|
||||
"""
|
||||
import config
|
||||
from logger_manager import logger_manager
|
||||
from maix import time as maix_time
|
||||
|
||||
|
||||
_INA226_PRESENT = None
|
||||
|
||||
|
||||
def _ina226_ready() -> bool:
|
||||
"""
|
||||
是否允许访问 INA226。
|
||||
|
||||
重要:
|
||||
- 这里刻意不做任何 I2C 探测/读写。
|
||||
- 经验上,在 INA226 未供电/未响应时,I2C 的 readfrom_mem 可能直接触发底层崩溃(SIGSEGV),try/except 无法拦截。
|
||||
- 因此只在开机 init_ina226() 成功后才允许后续读电压/电流。
|
||||
"""
|
||||
return bool(getattr(config, "INA226_ENABLE", True)) and (_INA226_PRESENT is True)
|
||||
|
||||
|
||||
def write_register(reg, value):
|
||||
"""写入INA226寄存器"""
|
||||
from hardware import hardware_manager
|
||||
logger = logger_manager.logger
|
||||
data = [(value >> 8) & 0xFF, value & 0xFF]
|
||||
# 某些底层驱动在失败时只打印 “write failed” 并返回 -1,而不是抛异常;
|
||||
# 为避免误判“初始化成功”导致后续 readfrom_mem SIGSEGV,这里把失败显式转成异常。
|
||||
ret = hardware_manager.bus.writeto_mem(config.INA226_ADDR, reg, bytes(data))
|
||||
if isinstance(ret, int) and ret < 0:
|
||||
if logger:
|
||||
logger.error(f"[INA226] writeto_mem 失败: addr=0x{config.INA226_ADDR:02X} reg=0x{reg:02X} ret={ret}")
|
||||
raise OSError(ret)
|
||||
|
||||
|
||||
def read_register(reg):
|
||||
"""读取INA226寄存器"""
|
||||
from hardware import hardware_manager
|
||||
data = hardware_manager.bus.readfrom_mem(config.INA226_ADDR, reg, 2)
|
||||
return (data[0] << 8) | data[1]
|
||||
|
||||
|
||||
def init_ina226():
|
||||
"""初始化 INA226 芯片:配置模式 + 校准值"""
|
||||
global _INA226_PRESENT
|
||||
logger = logger_manager.logger
|
||||
if not getattr(config, "INA226_ENABLE", True):
|
||||
if logger:
|
||||
logger.info("[INA226] INA226_ENABLE=False,跳过初始化与 I2C 探测")
|
||||
# 显式标记不可用,避免后续误读
|
||||
_INA226_PRESENT = False
|
||||
return False
|
||||
try:
|
||||
# 仅通过“写寄存器成功”来判定可用,避免额外的读操作触发底层崩溃
|
||||
write_register(config.REG_CONFIGURATION, 0x4527)
|
||||
write_register(config.REG_CALIBRATION, config.CALIBRATION_VALUE)
|
||||
_INA226_PRESENT = True
|
||||
return True
|
||||
except Exception as e:
|
||||
_INA226_PRESENT = False
|
||||
if logger:
|
||||
logger.error(f"[INA226] 初始化失败:{e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_bus_voltage():
|
||||
"""读取总线电压(单位:V)。未探测到 INA226 或读失败时返回 0.0(上报用,避免 null)。"""
|
||||
logger = logger_manager.logger
|
||||
if not _ina226_ready():
|
||||
return 0.0
|
||||
try:
|
||||
raw = read_register(config.REG_BUS_VOLTAGE)
|
||||
return raw * 1.25 / 1000
|
||||
except Exception as e:
|
||||
if logger:
|
||||
logger.error(f"[INA226] 读取电压失败:{e}")
|
||||
return 0.0
|
||||
|
||||
|
||||
def get_current():
|
||||
"""
|
||||
读取电流(单位:mA)
|
||||
正数表示充电,负数表示放电
|
||||
|
||||
INA226 电流计算公式:
|
||||
Current = (Current Register Value) × Current_LSB
|
||||
Current_LSB = 0.001 × CALIBRATION_VALUE / 4096
|
||||
"""
|
||||
try:
|
||||
if not _ina226_ready():
|
||||
return 0.0
|
||||
raw = read_register(config.REG_CURRENT)
|
||||
# INA226 电流寄存器是16位有符号整数
|
||||
# 最高位是符号位:0=正(充电),1=负(放电)
|
||||
# 计算 Current_LSB(根据 CALIBRATION_VALUE)
|
||||
current_lsb = 0.001 * config.CALIBRATION_VALUE / 4096 # 单位:A
|
||||
# 处理有符号数:如果最高位为1,转换为负数
|
||||
if raw & 0x8000: # 最高位为1,表示负数(放电)
|
||||
signed_raw = raw - 0x10000 # 转换为有符号整数
|
||||
else: # 最高位为0,表示正数(充电)
|
||||
signed_raw = raw
|
||||
# 转换为毫安
|
||||
current_ma = signed_raw * current_lsb * 1000
|
||||
return current_ma
|
||||
except Exception as e:
|
||||
logger = logger_manager.logger
|
||||
if logger:
|
||||
logger.error(f"[INA226] 读取电流失败: {e}")
|
||||
else:
|
||||
print(f"[INA226] 读取电流失败: {e}")
|
||||
return 0.0
|
||||
|
||||
|
||||
def is_charging(threshold_ma=10.0):
|
||||
"""
|
||||
检测是否在充电(通过电流方向判断)
|
||||
|
||||
Args:
|
||||
threshold_ma: 电流阈值(毫安),超过此值认为在充电,默认10mA
|
||||
|
||||
Returns:
|
||||
True: 正在充电
|
||||
False: 未充电或读取失败
|
||||
"""
|
||||
try:
|
||||
current = get_current()
|
||||
is_charge = current > threshold_ma
|
||||
return is_charge
|
||||
except Exception as e:
|
||||
logger = logger_manager.logger
|
||||
if logger:
|
||||
logger.error(f"[CHARGE] 检测充电状态失败: {e}")
|
||||
else:
|
||||
print(f"[CHARGE] 检测充电状态失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def voltage_to_percent(voltage):
|
||||
"""
|
||||
根据电压估算电池百分比(高密度查表插值 + 滤波)。
|
||||
|
||||
- 电压先做 5 点移动平均(抑制瞬时抖动)
|
||||
- SOC 再做一阶低通(抑制“跳电量”)
|
||||
|
||||
注意:
|
||||
- 该方法仍是“开路电压→SOC”的近似;负载较大/瞬时大电流时电压会下沉,SOC 会偏低。
|
||||
- 滤波会带来滞后:电量变化会更平滑,但更新更慢。
|
||||
"""
|
||||
if voltage is None:
|
||||
return 0
|
||||
try:
|
||||
v = float(voltage)
|
||||
except Exception:
|
||||
return 0
|
||||
if v <= 0:
|
||||
return 0
|
||||
return int(int(_BATTERY_MONITOR.get_soc(v) * 10) / 10) # 截断而不是四舍五入
|
||||
|
||||
|
||||
class BatteryMonitor:
|
||||
"""
|
||||
电压→SOC 估算器(查表 + 线性插值 + 双重滤波)。
|
||||
|
||||
说明:
|
||||
- 表为单节锂电“静态电压”近似曲线;不同电池/温度/老化会有偏差。
|
||||
- 这里不区分充电/放电曲线(滞后),主要用于“显示电量/粗略判断”。
|
||||
"""
|
||||
|
||||
def __init__(self, avg_window: int = 5, alpha: float = 0.2):
|
||||
# 电压-SOC对照表(电压从高到低)
|
||||
self.voltages = [
|
||||
4.20, 4.15, 4.10, 4.05, 4.00,
|
||||
3.95, 3.90, 3.88, 3.85, 3.82,
|
||||
3.80, 3.78, 3.75, 3.72, 3.70,
|
||||
3.65, 3.60, 3.55, 3.50, 3.45,
|
||||
3.40, 3.35, 3.30, 3.20, 2.50,
|
||||
]
|
||||
self.socs = [
|
||||
100, 98, 95, 90, 85,
|
||||
80, 75, 72, 68, 64,
|
||||
60, 56, 52, 48, 44,
|
||||
38, 32, 26, 20, 14,
|
||||
10, 6, 3, 1, 0,
|
||||
]
|
||||
|
||||
self.avg_window = max(1, int(avg_window))
|
||||
self.alpha = float(alpha) if alpha is not None else 0.2
|
||||
if not (0.0 < self.alpha <= 1.0):
|
||||
self.alpha = 0.2
|
||||
|
||||
self.voltage_history = []
|
||||
self.last_soc = 50.0
|
||||
|
||||
def _voltage_to_soc_raw(self, voltage: float) -> float:
|
||||
# 越界
|
||||
if voltage >= self.voltages[0]:
|
||||
return 100.0
|
||||
if voltage <= self.voltages[-1]:
|
||||
return 0.0
|
||||
|
||||
# 表是降序,二分查找
|
||||
left, right = 0, len(self.voltages) - 1
|
||||
while left <= right:
|
||||
mid = (left + right) // 2
|
||||
vm = self.voltages[mid]
|
||||
if vm == voltage:
|
||||
return float(self.socs[mid])
|
||||
elif vm < voltage:
|
||||
right = mid - 1
|
||||
else:
|
||||
left = mid + 1
|
||||
|
||||
# 线性插值:right 在高电压侧,left 在低电压侧(降序表)
|
||||
# 例:voltages = [4.2,4.15,...],则 v_high=voltages[right] >= voltage >= voltages[left]=v_low
|
||||
v_high, v_low = float(self.voltages[right]), float(self.voltages[left])
|
||||
soc_high, soc_low = float(self.socs[right]), float(self.socs[left])
|
||||
if abs(v_high - v_low) < 1e-9:
|
||||
return soc_low
|
||||
soc = soc_low + (voltage - v_low) * (soc_high - soc_low) / (v_high - v_low)
|
||||
return soc
|
||||
|
||||
def get_soc(self, raw_voltage: float) -> float:
|
||||
# 1) 电压滤波(移动平均)
|
||||
self.voltage_history.append(float(raw_voltage))
|
||||
if len(self.voltage_history) > self.avg_window:
|
||||
self.voltage_history.pop(0)
|
||||
voltage = sum(self.voltage_history) / float(len(self.voltage_history))
|
||||
|
||||
# 2) 查表插值
|
||||
raw_soc = self._voltage_to_soc_raw(voltage)
|
||||
|
||||
# 3) SOC 低通滤波
|
||||
a = self.alpha
|
||||
self.last_soc = a * raw_soc + (1.0 - a) * float(self.last_soc)
|
||||
|
||||
# clip
|
||||
if self.last_soc < 0.0:
|
||||
self.last_soc = 0.0
|
||||
if self.last_soc > 100.0:
|
||||
self.last_soc = 100.0
|
||||
return float(self.last_soc)
|
||||
|
||||
|
||||
# 模块级单例:保留历史,实现平滑(进程重启会重置)
|
||||
_BATTERY_MONITOR = BatteryMonitor(
|
||||
avg_window=int(getattr(config, "BATTERY_SOC_AVG_WINDOW", 5)),
|
||||
alpha=float(getattr(config, "BATTERY_SOC_LPF_ALPHA", 0.2)),
|
||||
)
|
||||
|
||||
33
server.pem
Normal file
33
server.pem
Normal file
@@ -0,0 +1,33 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIFwjCCA6qgAwIBAgIUAZIGjFLTekYI+IIquQ/87qLDuNAwDQYJKoZIhvcNAQEL
|
||||
BQAwXjELMAkGA1UEBhMCQ04xDjAMBgNVBAgMBUxvY2FsMQ4wDAYDVQQHDAVMb2Nh
|
||||
bDEOMAwGA1UECgwFTG9jYWwxHzAdBgNVBAMMFnd3dy5zaGVsaW5neGluZ3FpdS5j
|
||||
b20wIBcNMjYwNDA3MDc0NDI2WhgPMjEyNjAzMTQwNzQ0MjZaMF4xCzAJBgNVBAYT
|
||||
AkNOMQ4wDAYDVQQIDAVMb2NhbDEOMAwGA1UEBwwFTG9jYWwxDjAMBgNVBAoMBUxv
|
||||
Y2FsMR8wHQYDVQQDDBZ3d3cuc2hlbGluZ3hpbmdxaXUuY29tMIICIjANBgkqhkiG
|
||||
9w0BAQEFAAOCAg8AMIICCgKCAgEAvKRcWr8QeT1OzhMbWlHmqxmduE+e7r2Oet9I
|
||||
mU4O888U1X1YKaIDnq+zqRCNteid3jrOWucDLReZzNnrZ4l3Jq9nbWuTwj9Y9vCq
|
||||
ahW3K3BOhnuJ+qvqX2Izn1Z9iNCFhXnUaFy8+iP0nJNNIRXwg7ioKbY6+SaTbBzI
|
||||
vfG33MjOmwnQlqZzdGyNpvieO9XzqVyRxeDen/LJf4Z1NocP2rOjqQC3dIDXOfBt
|
||||
/ZOZymb4XwQ9b/t+6WJn9Zfycw0tp/7GqI+vqLDUMpipO4ahmybJPO02IhokZ09t
|
||||
BnCXe0enLnMAshIipTxSaJEick9HnQVSUzF+9A1F0cCFAhS8cM/04aksfYsJD2xj
|
||||
riiVHVoVo6tb0GJSCM+b0j9ObH9bDx3DKfy9EcqP25mJxWQTuT8G0oiyuxE5knjA
|
||||
HL7yjwd5gVSuig+ACnxE3vITeVKtvyep7sD4tJqkN93t7OMeBRFMGsYpJ8w+8u6X
|
||||
+9/RmMcOnuNcT/4HrOuAtlAnM1D44MSI1RLaOCJJ9evqhpWdktfn2Uv4gCnaTjUr
|
||||
OiEU/G+lquST2kggjbcReLqkk+7yN3XkaR9dun4iV35WfEo1ENThVhLPGV61LaJq
|
||||
PwbjltQlkcAFPJ1GJyE9FVO79bB51d0w/rlI/CcDUpTRMaXR35EmTjxvXOr/a/XI
|
||||
56GUNaUCAwEAAaN2MHQwHQYDVR0OBBYEFH1HCDm4N7LMhIX2Fb2FXAfdyhwQMB8G
|
||||
A1UdIwQYMBaAFH1HCDm4N7LMhIX2Fb2FXAfdyhwQMA8GA1UdEwEB/wQFMAMBAf8w
|
||||
IQYDVR0RBBowGIIWd3d3LnNoZWxpbmd4aW5ncWl1LmNvbTANBgkqhkiG9w0BAQsF
|
||||
AAOCAgEAG/PMwXCXJOaqCpU/LaY6w04ue6wk95RbPXf4JH4CrrLUfgyUmFlNNQPA
|
||||
LuZSBRI6KUGkTvzuz/3ofZHVEin3CyE5NadB3UItpfA4Wl4r3jMPifIgnA/NT8xo
|
||||
GE1gYaDbcfJNE8jy6GebjZekbVrPvCY9YgcUT2AmW5fcbnCTy+/iC7lf9MvvqHTJ
|
||||
H5zvOp5nyWJYWYsvvif3Y7dp00ytg9I8/LSgUspKwB8qSWPWV8z4WsV6sc1mNqVS
|
||||
nFBDkgzZxr4ZYlhVLzbSoab8D4A/z6riEMqv4S+oF5VkaJLhsN8vgHh9aPspCC3Q
|
||||
zhcosH8XmNmJmT/X64FhhRqxAqX65WanVQABtBS/vsC+FAQDGMb3RkZSbLEnIlgj
|
||||
bx/6bSkhHl+J2xIqA7tLvYhRSvM3H12X7VSVc+tkVzI5JoUSugZLxxRDGpYgkvRz
|
||||
SPFCqb9eTn5ES5gnQX6+E+f/E/WQTmadolSbEppdxNZW7AaIUdQo0aFxFwctwhA2
|
||||
YNUG9oW2TXAZjSECyTo28NFkFfwBhpHWigFCANNCd8Nrn0k0YMuJOkqW5e4w3/24
|
||||
/IxM/C9K7aAx4S1XZ16Nvh5pZQduEGKTSUYMJ/uV26Mf4ZGroUfGB9tBguK5rYbL
|
||||
UlRvtU9mkZPK04GbLsoo+8tZTDRtkuCiC19xk33XiitZrmavc24=
|
||||
-----END CERTIFICATE-----
|
||||
@@ -47,6 +47,7 @@ def set_autostart_app(app_id):
|
||||
if __name__ == "__main__":
|
||||
new_autostart_app_id = "t11" # change to app_id you want to set
|
||||
# new_autostart_app_id = None # remove autostart
|
||||
# new_autostart_app_id = "z1222" # change to app_id you want to set
|
||||
|
||||
list_apps()
|
||||
print("Before set autostart appid:", get_curr_autostart_app())
|
||||
|
||||
546
shoot_manager.py
Normal file
546
shoot_manager.py
Normal file
@@ -0,0 +1,546 @@
|
||||
import os
|
||||
import threading
|
||||
import time as time_std
|
||||
|
||||
import config
|
||||
from camera_manager import camera_manager
|
||||
from laser_manager import laser_manager
|
||||
from logger_manager import logger_manager
|
||||
from network import network_manager
|
||||
from triangle_target import load_camera_from_xml, load_triangle_positions, try_triangle_scoring
|
||||
from vision import estimate_distance, detect_circle_v3, enqueue_save_shot
|
||||
from maix import image, time
|
||||
|
||||
# 缓存相机标定与三角形位置,避免每次射箭重复读磁盘
|
||||
_tri_calib_cache = None
|
||||
|
||||
def _get_triangle_calib():
|
||||
"""返回 (K, dist, marker_positions);首次调用时从磁盘加载并缓存。"""
|
||||
global _tri_calib_cache
|
||||
if _tri_calib_cache is not None:
|
||||
return _tri_calib_cache
|
||||
calib_path = getattr(config, "CAMERA_CALIB_XML", "")
|
||||
tri_json = getattr(config, "TRIANGLE_POSITIONS_JSON", "")
|
||||
if not (os.path.isfile(calib_path) and os.path.isfile(tri_json)):
|
||||
_tri_calib_cache = (None, None, None)
|
||||
return _tri_calib_cache
|
||||
K, dist = load_camera_from_xml(calib_path)
|
||||
pos = load_triangle_positions(tri_json)
|
||||
_tri_calib_cache = (K, dist, pos)
|
||||
return _tri_calib_cache
|
||||
|
||||
|
||||
def preload_triangle_calib():
|
||||
"""
|
||||
启动阶段预加载三角形标定与坐标文件,避免首次射箭触发时的读盘/解析开销。
|
||||
"""
|
||||
try:
|
||||
_get_triangle_calib()
|
||||
except Exception:
|
||||
# 预加载失败不影响主流程;射箭时会再次按需尝试
|
||||
pass
|
||||
|
||||
|
||||
def analyze_shot(frame, laser_point=None):
|
||||
"""
|
||||
分析射箭结果(算法部分,可迁移到C++)
|
||||
:param frame: 图像帧
|
||||
:param laser_point: 激光点坐标 (x, y)
|
||||
:return: 包含分析结果的字典
|
||||
|
||||
优先级:
|
||||
1. 三角形单应性(USE_TRIANGLE_OFFSET=True 时)— 成功则直接返回,跳过圆形检测
|
||||
2. 圆形检测(三角形不可用或识别失败时兜底)
|
||||
"""
|
||||
logger = logger_manager.logger
|
||||
from datetime import datetime
|
||||
|
||||
# ── Step 1: 确定激光点 ────────────────────────────────────────────────────
|
||||
laser_point_method = None
|
||||
distance_m_first = None
|
||||
|
||||
if config.HARDCODE_LASER_POINT:
|
||||
laser_point = laser_manager.laser_point
|
||||
laser_point_method = "hardcode"
|
||||
elif laser_manager.has_calibrated_point():
|
||||
laser_point = laser_manager.laser_point
|
||||
laser_point_method = "calibrated"
|
||||
if logger:
|
||||
logger.info(f"[算法] 使用校准值: {laser_manager.laser_point}")
|
||||
else:
|
||||
# 动态模式:先做一次无激光点检测以估算距离,再推算激光点
|
||||
_, _, _, _, best_radius1_temp, _ = detect_circle_v3(frame, None)
|
||||
distance_m_first = estimate_distance(best_radius1_temp) if best_radius1_temp else None
|
||||
if distance_m_first and distance_m_first > 0:
|
||||
laser_point = laser_manager.calculate_laser_point_from_distance(distance_m_first)
|
||||
laser_point_method = "dynamic"
|
||||
if logger:
|
||||
logger.info(f"[算法] 使用比例尺: {laser_point}")
|
||||
else:
|
||||
laser_point = laser_manager.laser_point
|
||||
laser_point_method = "default"
|
||||
if logger:
|
||||
logger.info(f"[算法] 使用默认值: {laser_point}")
|
||||
|
||||
if laser_point is None:
|
||||
return {"success": False, "reason": "laser_point_not_initialized"}
|
||||
|
||||
x, y = laser_point
|
||||
|
||||
# ── Step 2: 提前转换一次图像,两个检测线程共享(只读)────────────────────────
|
||||
img_cv = image.image2cv(frame, False, False)
|
||||
|
||||
# ── Step 3: 检查三角形是否可用 ────────────────────────────────────────────────
|
||||
use_tri = getattr(config, "USE_TRIANGLE_OFFSET", False)
|
||||
K = dist_coef = pos = None
|
||||
if use_tri:
|
||||
K, dist_coef, pos = _get_triangle_calib()
|
||||
use_tri = K is not None and dist_coef is not None and pos
|
||||
|
||||
def _build_circle_result(cdata, yolo_roi_xyxy=None):
|
||||
"""从圆形检测结果构建 analyze_shot 返回值。"""
|
||||
r_img, center, radius, method, best_radius1, ellipse_params = cdata
|
||||
dx, dy = None, None
|
||||
d_m = distance_m_first
|
||||
if center and radius:
|
||||
dx, dy = laser_manager.compute_laser_position(center, (x, y), radius, method)
|
||||
d_m = estimate_distance(best_radius1) if best_radius1 else distance_m_first
|
||||
out = {
|
||||
"success": True,
|
||||
"result_img": r_img,
|
||||
"center": center, "radius": radius, "method": method,
|
||||
"best_radius1": best_radius1, "ellipse_params": ellipse_params,
|
||||
"dx": dx, "dy": dy, "distance_m": d_m,
|
||||
"laser_point": laser_point, "laser_point_method": laser_point_method,
|
||||
"offset_method": "yellow_ellipse" if ellipse_params else "yellow_circle",
|
||||
"distance_method": "yellow_radius",
|
||||
}
|
||||
if yolo_roi_xyxy is not None:
|
||||
out["yolo_roi_xyxy"] = yolo_roi_xyxy
|
||||
return out
|
||||
|
||||
if not use_tri:
|
||||
# 三角形未配置,直接跑圆形检测
|
||||
return _build_circle_result(
|
||||
detect_circle_v3(frame, laser_point, img_cv=img_cv)
|
||||
)
|
||||
|
||||
# ── Step 4: 先独占跑三角形,超时或失败后再跑圆形(不与圆心并行,避免抢 CPU)──
|
||||
roi_xyxy = None
|
||||
yolo_ring_ms = 0.0
|
||||
yolo_black_ms = 0.0
|
||||
if getattr(config, "TRIANGLE_YOLO_ROI_ENABLE", False):
|
||||
_t_yolo_ring = time_std.perf_counter()
|
||||
try:
|
||||
from target_roi_yolo import try_get_triangle_roi_from_yolo
|
||||
roi_xyxy = try_get_triangle_roi_from_yolo(
|
||||
frame, img_cv.shape[1], img_cv.shape[0], logger
|
||||
)
|
||||
except Exception as e:
|
||||
if logger:
|
||||
logger.warning(f"[YOLO-ROI] {e}")
|
||||
finally:
|
||||
yolo_ring_ms = (time_std.perf_counter() - _t_yolo_ring) * 1000.0
|
||||
|
||||
_loc_mode = str(
|
||||
getattr(config, "TRIANGLE_BLACK_TRIANGLE_LOCATE_MODE", "yolo")
|
||||
).lower().strip()
|
||||
if _loc_mode not in ("yolo", "traditional"):
|
||||
_loc_mode = "yolo"
|
||||
|
||||
black_boxes_work = None
|
||||
_run_stage2_black_yolo = (
|
||||
_loc_mode == "yolo"
|
||||
and getattr(config, "TRIANGLE_BLACK_YOLO_ENABLE", False)
|
||||
and roi_xyxy is not None
|
||||
)
|
||||
if _run_stage2_black_yolo:
|
||||
_t_yolo_black = time_std.perf_counter()
|
||||
try:
|
||||
from target_roi_yolo import try_black_triangle_boxes_work
|
||||
|
||||
black_boxes_work = try_black_triangle_boxes_work(
|
||||
img_cv, roi_xyxy, logger
|
||||
)
|
||||
except Exception as e:
|
||||
if logger:
|
||||
logger.warning(f"[YOLO-BLACK] {e}")
|
||||
finally:
|
||||
yolo_black_ms = (time_std.perf_counter() - _t_yolo_black) * 1000.0
|
||||
elif (
|
||||
logger
|
||||
and _loc_mode == "traditional"
|
||||
and roi_xyxy is not None
|
||||
and getattr(config, "TRIANGLE_BLACK_YOLO_ENABLE", False)
|
||||
):
|
||||
logger.info(
|
||||
"[TRI] TRIANGLE_BLACK_TRIANGLE_LOCATE_MODE=traditional:跳过 Stage2 黑三角 YOLO,"
|
||||
"仅在 Stage1 裁切内跑整幅传统三角检测"
|
||||
)
|
||||
|
||||
tri_result = {}
|
||||
|
||||
def _run_triangle():
|
||||
try:
|
||||
logger.info(f"[TRI] begin {datetime.now()}")
|
||||
logger.info(f"[TRI] K: {K}, dist: {dist_coef}, pos: {pos}, {datetime.now()}")
|
||||
_t_wall_try = time_std.perf_counter()
|
||||
tri = try_triangle_scoring(
|
||||
img_cv, (x, y), pos, K, dist_coef,
|
||||
size_range=getattr(config, "TRIANGLE_SIZE_RANGE", (8, 500)),
|
||||
roi_xyxy=roi_xyxy,
|
||||
black_yolo_boxes_work=black_boxes_work,
|
||||
yolo_ring_ms=yolo_ring_ms,
|
||||
yolo_black_ms=yolo_black_ms,
|
||||
)
|
||||
_wall_try_ms = (time_std.perf_counter() - _t_wall_try) * 1000.0
|
||||
if logger and bool(getattr(config, "TRIANGLE_LOG_E2E_TIMING", True)):
|
||||
_e2e = float(yolo_ring_ms) + float(yolo_black_ms) + float(_wall_try_ms)
|
||||
logger.info(
|
||||
f"[TRI] timing_e2e_triangle_ms={_e2e:.1f} "
|
||||
f"(yolo_ring={float(yolo_ring_ms):.1f} yolo_black={float(yolo_black_ms):.1f} "
|
||||
f"try_triangle_wall={_wall_try_ms:.1f} locate_mode={_loc_mode})"
|
||||
)
|
||||
logger.info(f"[TRI] tri: {tri}, {datetime.now()}")
|
||||
tri_result['data'] = tri
|
||||
except Exception as e:
|
||||
logger.error(f"[TRI] 三角形路径异常: {e}")
|
||||
tri_result['data'] = {'ok': False}
|
||||
|
||||
t_tri = threading.Thread(target=_run_triangle, daemon=True)
|
||||
t_tri.start()
|
||||
|
||||
tri_timeout_s = float(getattr(config, "TRIANGLE_TIMEOUT_MS", 2000)) / 1000.0
|
||||
|
||||
t_tri.join(timeout=tri_timeout_s)
|
||||
|
||||
def _tri_ok_validated(tri):
|
||||
try:
|
||||
import numpy as _np
|
||||
ok = bool(tri.get('ok'))
|
||||
if not ok:
|
||||
return False
|
||||
|
||||
dxv = tri.get("dx_cm")
|
||||
dyv = tri.get("dy_cm")
|
||||
H = tri.get("homography")
|
||||
if not _np.isfinite(dxv) or not _np.isfinite(dyv):
|
||||
logger.warning("[TRI] dx/dy 非有限值,判定为误检")
|
||||
return False
|
||||
if H is not None and not _np.all(_np.isfinite(H)):
|
||||
logger.warning("[TRI] 单应矩阵含非有限值,判定为误检")
|
||||
return False
|
||||
|
||||
# ── 检查1:单应矩阵 x/y 缩放比(靶标是正方形,H[0,0]≈H[1,1])──
|
||||
if H is not None:
|
||||
sx = abs(float(H[0, 0]))
|
||||
sy = abs(float(H[1, 1]))
|
||||
if sy > 1e-6:
|
||||
hxy_ratio = sx / sy
|
||||
# 正常拍摄比值在 0.6~1.7 之间;超出则四点严重变形,说明有误检
|
||||
if not (0.6 <= hxy_ratio <= 1.7):
|
||||
logger.warning(
|
||||
f"[TRI] 单应矩阵 sx/sy={hxy_ratio:.2f} 偏差过大,判定为误检,回退圆心"
|
||||
)
|
||||
return False
|
||||
|
||||
# ── 检查2:可选配置距离上下限(写 0 表示不启用)──────────────────
|
||||
dist_m = tri.get("distance_m")
|
||||
if dist_m is not None:
|
||||
try:
|
||||
import config as _vc
|
||||
d_min = float(getattr(_vc, "TRIANGLE_DISTANCE_MIN_M", 0.0))
|
||||
d_max = float(getattr(_vc, "TRIANGLE_DISTANCE_MAX_M", 0.0))
|
||||
except Exception:
|
||||
d_min, d_max = 0.0, 0.0
|
||||
if d_min > 0 and d_max > d_min:
|
||||
if not (d_min <= dist_m <= d_max):
|
||||
logger.warning(
|
||||
f"[TRI] 距离 {dist_m:.2f}m 超出配置范围 [{d_min},{d_max}],判定为误检,回退圆心"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return bool(tri.get('ok'))
|
||||
|
||||
def _build_tri_result(tri, yolo_roi_xyxy=None):
|
||||
out = {
|
||||
"success": True,
|
||||
"result_img": frame,
|
||||
"center": None, "radius": None,
|
||||
"method": "triangle_homography",
|
||||
"best_radius1": None, "ellipse_params": None,
|
||||
"dx": tri["dx_cm"], "dy": tri["dy_cm"],
|
||||
"distance_m": tri.get("distance_m") or distance_m_first,
|
||||
"laser_point": laser_point, "laser_point_method": laser_point_method,
|
||||
"offset_method": tri.get("offset_method") or "triangle_homography",
|
||||
"distance_method": tri.get("distance_method") or "pnp_triangle",
|
||||
"tri_markers": tri.get("markers", []),
|
||||
"tri_markers_completed": tri.get("markers_completed", []),
|
||||
"tri_homography": tri.get("homography"),
|
||||
}
|
||||
if yolo_roi_xyxy is not None:
|
||||
out["yolo_roi_xyxy"] = yolo_roi_xyxy
|
||||
return out
|
||||
|
||||
# 三角形在超时内完成
|
||||
if not t_tri.is_alive():
|
||||
tri = tri_result.get('data', {})
|
||||
if _tri_ok_validated(tri):
|
||||
logger.info(f"[TRI] end {datetime.now()} — 使用三角形结果(dx={tri['dx_cm']:.2f},dy={tri['dy_cm']:.2f}cm)")
|
||||
return _build_tri_result(tri, roi_xyxy)
|
||||
logger.info(f"[TRI] end(tri_failed, fallback circle) {datetime.now()}")
|
||||
else:
|
||||
logger.warning(f"[TRI] 超时 {tri_timeout_s:.2f}s 仍未结束,启动圆心算法(三角形仍在后台)")
|
||||
|
||||
# 三角形超时或失败 → 跑圆心;圆心跑完后再检查三角形是否已结束
|
||||
try:
|
||||
cdata = detect_circle_v3(frame, laser_point, img_cv=img_cv)
|
||||
except Exception as e:
|
||||
logger.error(f"[CIRCLE] 圆形检测异常: {e}")
|
||||
cdata = (frame, None, None, None, None, None)
|
||||
|
||||
# 圆心跑完后,若三角形恰好已经结束且结果有效,优先用三角形
|
||||
if not t_tri.is_alive():
|
||||
tri = tri_result.get('data', {})
|
||||
if _tri_ok_validated(tri):
|
||||
logger.info(f"[TRI] 圆心跑完后三角形已就绪 — 优先使用三角形结果(dx={tri['dx_cm']:.2f},dy={tri['dy_cm']:.2f}cm)")
|
||||
return _build_tri_result(tri, roi_xyxy)
|
||||
|
||||
return _build_circle_result(cdata, roi_xyxy)
|
||||
|
||||
|
||||
def process_shot(adc_val):
|
||||
"""
|
||||
处理射箭事件(逻辑控制部分)
|
||||
:param adc_val: ADC触发值
|
||||
:return: None
|
||||
"""
|
||||
logger = logger_manager.logger
|
||||
|
||||
try:
|
||||
network_manager.safe_enqueue({"shoot_event": "start"}, msg_type=2, high=True)
|
||||
frame = camera_manager.read_frame()
|
||||
|
||||
# 调用算法分析
|
||||
analysis_result = analyze_shot(frame)
|
||||
|
||||
if not analysis_result.get("success"):
|
||||
reason = analysis_result.get("reason", "unknown")
|
||||
if logger:
|
||||
logger.warning(f"[MAIN] 射箭分析失败: {reason}")
|
||||
time.sleep_ms(100)
|
||||
return
|
||||
|
||||
# 提取分析结果
|
||||
result_img = analysis_result["result_img"]
|
||||
center = analysis_result["center"]
|
||||
radius = analysis_result["radius"]
|
||||
method = analysis_result["method"]
|
||||
ellipse_params = analysis_result["ellipse_params"]
|
||||
dx = analysis_result["dx"]
|
||||
dy = analysis_result["dy"]
|
||||
distance_m = analysis_result["distance_m"]
|
||||
laser_point = analysis_result["laser_point"]
|
||||
laser_point_method = analysis_result["laser_point_method"]
|
||||
offset_method = analysis_result.get("offset_method", "yellow_circle")
|
||||
distance_method = analysis_result.get("distance_method", "yellow_radius")
|
||||
tri_markers = analysis_result.get("tri_markers", [])
|
||||
tri_markers_completed = analysis_result.get("tri_markers_completed", [])
|
||||
tri_homography = analysis_result.get("tri_homography")
|
||||
yolo_roi_xyxy = analysis_result.get("yolo_roi_xyxy")
|
||||
draw_yolo_roi = (
|
||||
yolo_roi_xyxy is not None
|
||||
and getattr(config, "TRIANGLE_YOLO_DRAW_ROI_ON_SHOT", True)
|
||||
)
|
||||
x, y = laser_point
|
||||
|
||||
# 三角形路径成功时 center/radius 为空是正常的;此时用 triangle 方法名用于保存文件名与上报字段 m
|
||||
if (not method) and tri_markers:
|
||||
method = "triangle_homography"
|
||||
|
||||
if config.SHOW_CAMERA_PHOTO_WHILE_SHOOTING:
|
||||
camera_manager.show(result_img)
|
||||
|
||||
if dx is None and dy is None and logger:
|
||||
logger.warning("[MAIN] 未检测到偏移量(三角形与圆形均失败),但会保存图像")
|
||||
|
||||
# 生成射箭ID
|
||||
from shot_id_generator import shot_id_generator
|
||||
shot_id = shot_id_generator.generate_id()
|
||||
|
||||
if logger:
|
||||
logger.info(f"[MAIN] 射箭ID: {shot_id}")
|
||||
|
||||
laser_distance_m = None
|
||||
laser_signal_quality = 0
|
||||
|
||||
# x,y 单位:物理厘米(compute_laser_position 与三角形单应性均输出物理 cm)
|
||||
# 未检测到靶心时 x/y 用 200.0(脱靶标志)
|
||||
srv_x = round(float(dx), 4) if dx is not None else 200.0
|
||||
srv_y = round(float(dy), 4) if dy is not None else 200.0
|
||||
|
||||
# 构造上报数据
|
||||
inner_data = {
|
||||
"shot_id": shot_id,
|
||||
"x": srv_x,
|
||||
"y": srv_y,
|
||||
"r": 20.0, # 保留字段(服务端当前忽略,物理外环半径 cm)
|
||||
"d": round((distance_m or 0.0) * 100),
|
||||
"d_laser": round((laser_distance_m or 0.0) * 100),
|
||||
"d_laser_quality": laser_signal_quality,
|
||||
"m": method if method else "no_target",
|
||||
"adc": adc_val,
|
||||
"laser_method": laser_point_method,
|
||||
"target_x": float(x),
|
||||
"target_y": float(y),
|
||||
"offset_method": offset_method,
|
||||
"distance_method": distance_method,
|
||||
}
|
||||
|
||||
if ellipse_params:
|
||||
(ell_center, (width, height), angle) = ellipse_params
|
||||
inner_data["ellipse_major_axis"] = float(max(width, height))
|
||||
inner_data["ellipse_minor_axis"] = float(min(width, height))
|
||||
inner_data["ellipse_angle"] = float(angle)
|
||||
inner_data["ellipse_center_x"] = float(ell_center[0])
|
||||
inner_data["ellipse_center_y"] = float(ell_center[1])
|
||||
else:
|
||||
inner_data["ellipse_major_axis"] = None
|
||||
inner_data["ellipse_minor_axis"] = None
|
||||
inner_data["ellipse_angle"] = None
|
||||
inner_data["ellipse_center_x"] = None
|
||||
inner_data["ellipse_center_y"] = None
|
||||
|
||||
report_data = {"cmd": 1, "data": inner_data}
|
||||
network_manager.safe_enqueue(report_data, msg_type=2, high=True)
|
||||
|
||||
# 数据上报后再画标注,不干扰检测阶段的原始画面
|
||||
if result_img is not None:
|
||||
# 1. 若有三角形标记,先用 cv2 画轮廓 / 顶点 / ID,再反推靶心位置
|
||||
if tri_markers:
|
||||
import cv2 as _cv2
|
||||
import numpy as _np
|
||||
_img_cv = image.image2cv(result_img, False, False)
|
||||
|
||||
# YOLO 靶环框在 vision.enqueue_save_shot 的 worker 里绘制,避免阻塞主流程
|
||||
|
||||
# 三角形轮廓 + 直角顶点 + ID
|
||||
for _m in tri_markers:
|
||||
_corners = _np.array(_m["corners"], dtype=_np.int32)
|
||||
_cv2.polylines(_img_cv, [_corners], True, (0, 255, 0), 2)
|
||||
_cx, _cy = int(_m["center"][0]), int(_m["center"][1])
|
||||
_cv2.circle(_img_cv, (_cx, _cy), 4, (0, 0, 255), -1)
|
||||
_cv2.putText(_img_cv, f"T{_m['id']}",
|
||||
(_cx - 18, _cy - 12),
|
||||
_cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 0), 1)
|
||||
|
||||
# 3点补全的虚拟角点:只画中心点 + 文本,避免误认为真实检测到的三角形
|
||||
try:
|
||||
if tri_markers_completed:
|
||||
for _m in tri_markers_completed:
|
||||
if not _m.get("is_virtual"):
|
||||
continue
|
||||
_cx, _cy = int(_m["center"][0]), int(_m["center"][1])
|
||||
_cv2.circle(_img_cv, (_cx, _cy), 6, (255, 0, 255), 2) # 紫色空心圈
|
||||
_cv2.putText(
|
||||
_img_cv,
|
||||
f"VT{_m['id']}",
|
||||
(_cx - 22, _cy - 12),
|
||||
_cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.55,
|
||||
(255, 0, 255),
|
||||
1,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 靶心(H_inv @ [0,0]):小红圆
|
||||
_center_px = None
|
||||
if tri_homography is not None:
|
||||
try:
|
||||
_H_inv = _np.linalg.inv(tri_homography)
|
||||
_c_img = _cv2.perspectiveTransform(
|
||||
_np.array([[[0.0, 0.0]]], dtype=_np.float32), _H_inv)[0][0]
|
||||
_ocx, _ocy = int(_c_img[0]), int(_c_img[1])
|
||||
_cv2.circle(_img_cv, (_ocx, _ocy), 5, (0, 0, 255), -1) # 实心
|
||||
_cv2.circle(_img_cv, (_ocx, _ocy), 9, (0, 0, 255), 1) # 外框
|
||||
_center_px = (_ocx, _ocy)
|
||||
logger.info(f"[算法] 靶心: {_center_px}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 叠加信息:落点-圆心距离 / 相机-靶距离等
|
||||
try:
|
||||
import math as _math
|
||||
_lines = []
|
||||
if dx is not None and dy is not None:
|
||||
_r_cm = _math.hypot(float(dx), float(dy))
|
||||
_lines.append(f"offset=({float(dx):.2f},{float(dy):.2f})cm |r|={_r_cm:.2f}cm")
|
||||
if distance_m is not None:
|
||||
_lines.append(f"cam_dist={float(distance_m):.2f}m ({distance_method})")
|
||||
if method:
|
||||
_lines.append(f"method={method}")
|
||||
if _lines:
|
||||
_y0 = 22
|
||||
for i, _t in enumerate(_lines):
|
||||
_cv2.putText(
|
||||
_img_cv,
|
||||
_t,
|
||||
(10, _y0 + i * 18),
|
||||
_cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(0, 255, 0),
|
||||
1,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
result_img = image.cv2image(_img_cv, False, False)
|
||||
|
||||
elif draw_yolo_roi:
|
||||
# 仅 YOLO 标注时也不在主线程画框,交给存图 worker
|
||||
pass
|
||||
|
||||
# 2. 激光十字线
|
||||
_lc = image.Color(config.LASER_COLOR[0], config.LASER_COLOR[1], config.LASER_COLOR[2])
|
||||
result_img.draw_line(int(x - config.LASER_LENGTH), int(y),
|
||||
int(x + config.LASER_LENGTH), int(y),
|
||||
_lc, config.LASER_THICKNESS)
|
||||
result_img.draw_line(int(x), int(y - config.LASER_LENGTH),
|
||||
int(x), int(y + config.LASER_LENGTH),
|
||||
_lc, config.LASER_THICKNESS)
|
||||
result_img.draw_circle(int(x), int(y), 1, _lc, config.LASER_THICKNESS)
|
||||
|
||||
# 闪一下激光(射箭反馈)
|
||||
if config.FLASH_LASER_WHILE_SHOOTING:
|
||||
laser_manager.flash_laser(config.FLASH_LASER_DURATION_MS)
|
||||
|
||||
# 保存图像(异步队列,与 main.py 一致)
|
||||
enqueue_save_shot(
|
||||
result_img,
|
||||
center,
|
||||
radius,
|
||||
method,
|
||||
ellipse_params,
|
||||
(x, y),
|
||||
distance_m,
|
||||
shot_id=shot_id,
|
||||
photo_dir=config.PHOTO_DIR if config.SAVE_IMAGE_ENABLED else None,
|
||||
yolo_roi_xyxy=yolo_roi_xyxy if draw_yolo_roi else None,
|
||||
)
|
||||
|
||||
if logger:
|
||||
if dx is not None and dy is not None:
|
||||
logger.info(f"射箭事件已加入发送队列(偏移=({dx:.2f},{dy:.2f})cm),ID: {shot_id}")
|
||||
else:
|
||||
logger.info(f"射箭事件已加入发送队列(未检测到偏移,已保存图像),ID: {shot_id}")
|
||||
|
||||
time.sleep_ms(100)
|
||||
except Exception as e:
|
||||
if logger:
|
||||
logger.error(f"[MAIN] 图像处理异常: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
time.sleep_ms(100)
|
||||
76
shot_id_generator.py
Normal file
76
shot_id_generator.py
Normal file
@@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
射箭ID生成器
|
||||
为每次射箭生成唯一ID,格式:{timestamp_ms}_{counter}
|
||||
"""
|
||||
from maix import time
|
||||
import threading
|
||||
|
||||
|
||||
class ShotIDGenerator:
|
||||
"""射箭ID生成器(单例)"""
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super(ShotIDGenerator, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._counter = 0
|
||||
self._last_timestamp_ms = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._initialized = True
|
||||
|
||||
def generate_id(self, device_id=None):
|
||||
"""
|
||||
生成唯一的射箭ID
|
||||
|
||||
Args:
|
||||
device_id: 可选的设备ID,如果提供则包含在ID中(格式:{device_id}_{timestamp_ms}_{counter})
|
||||
如果不提供,则使用简单格式(格式:{timestamp_ms}_{counter})
|
||||
|
||||
Returns:
|
||||
str: 唯一的射箭ID
|
||||
"""
|
||||
with self._lock:
|
||||
current_timestamp_ms = time.ticks_ms()
|
||||
|
||||
# 如果时间戳相同,增加计数器;否则重置计数器
|
||||
if current_timestamp_ms == self._last_timestamp_ms:
|
||||
self._counter += 1
|
||||
else:
|
||||
self._counter = 0
|
||||
self._last_timestamp_ms = current_timestamp_ms
|
||||
|
||||
# 生成ID
|
||||
if device_id:
|
||||
shot_id = f"{device_id}_{current_timestamp_ms}_{self._counter}"
|
||||
else:
|
||||
shot_id = f"{current_timestamp_ms}_{self._counter}"
|
||||
|
||||
return shot_id
|
||||
|
||||
def reset(self):
|
||||
"""重置计数器(通常不需要调用)"""
|
||||
with self._lock:
|
||||
self._counter = 0
|
||||
self._last_timestamp_ms = 0
|
||||
|
||||
|
||||
# 创建全局单例实例
|
||||
shot_id_generator = ShotIDGenerator()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
668
target_roi_yolo.py
Normal file
668
target_roi_yolo.py
Normal file
@@ -0,0 +1,668 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
MaixCAM NPU YOLOv5:先检靶环/整靶区域并裁切 ROI;黑三角 Stage2 在裁切图上推理(与训练一致),
|
||||
再在各子框上跑传统直角点算法。
|
||||
|
||||
- 相机全分辨率(如 640×480)与模型输入(如 320×320)不一致时,需把检测框从
|
||||
「网络输入坐标系」映回全图,或直接使用 Maix 已映射到源图坐标的模式(见 config)。
|
||||
|
||||
依赖:maix.nn.YOLOv5;靶环模型 config.TRIANGLE_YOLO_MODEL_PATH;黑三角模型
|
||||
config.TRIANGLE_BLACK_YOLO_MODEL_PATH(可多实例缓存,按路径区分)。
|
||||
|
||||
224×224、320×320 等「网络输入尺寸」由导出的 .mud 决定,运行时打印为 net_in=,无需在业务 config 里写死。
|
||||
|
||||
返回 (x0, y0, x1, y1) 为整幅 img_cv 上的轴对齐矩形,半开区间按三角形裁剪习惯:
|
||||
实际裁剪为 img[y0:y1, x0:x1]。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _stage2_roi_crop_save_worker(
|
||||
slab_rgb,
|
||||
out_local_boxes,
|
||||
rx0,
|
||||
ry0,
|
||||
rw,
|
||||
rh,
|
||||
base_dir,
|
||||
draw_boxes,
|
||||
jpeg_quality,
|
||||
roi_max_images,
|
||||
logger_ref,
|
||||
):
|
||||
"""后台写 Stage2 裁切 JPEG,避免阻塞 NPU 后续流程。"""
|
||||
try:
|
||||
import time
|
||||
|
||||
import cv2
|
||||
|
||||
os.makedirs(base_dir, exist_ok=True)
|
||||
fn = os.path.join(
|
||||
base_dir,
|
||||
f"stage2_roi_{rx0}_{ry0}_{rw}x{rh}_{int(time.time() * 1000)}.jpg",
|
||||
)
|
||||
bgr = cv2.cvtColor(slab_rgb, cv2.COLOR_RGB2BGR)
|
||||
if draw_boxes and out_local_boxes:
|
||||
for i, (bx0, by0, bx1, by1) in enumerate(out_local_boxes):
|
||||
x0, y0 = int(bx0), int(by0)
|
||||
x1, y1 = int(bx1) - 1, int(by1) - 1
|
||||
x1 = max(x0, min(x1, rw - 1))
|
||||
y1 = max(y0, min(y1, rh - 1))
|
||||
cv2.rectangle(bgr, (x0, y0), (x1, y1), (0, 255, 0), 2)
|
||||
cv2.putText(
|
||||
bgr,
|
||||
f"s2_{i}",
|
||||
(x0, max(0, y0 - 4)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(0, 255, 0),
|
||||
1,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
cv2.imwrite(fn, bgr, [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_quality)])
|
||||
try:
|
||||
from vision import prune_old_images_in_dir
|
||||
|
||||
prune_old_images_in_dir(
|
||||
base_dir, roi_max_images, logger_ref, "[YOLO-BLACK]"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
if logger_ref:
|
||||
extra = (
|
||||
f",已绘 Stage2 框×{len(out_local_boxes)}"
|
||||
if (draw_boxes and out_local_boxes)
|
||||
else ""
|
||||
)
|
||||
logger_ref.info(f"[YOLO-BLACK] 已保存 Stage1 裁切图(异步): {fn}{extra}")
|
||||
except Exception as e:
|
||||
if logger_ref:
|
||||
logger_ref.warning(f"[YOLO-BLACK] 异步保存裁切图失败: {e}")
|
||||
|
||||
_detector_by_path = {}
|
||||
|
||||
|
||||
def reset_yolo_detector_cache():
|
||||
"""切换模型路径时可调用(通常不必)。"""
|
||||
global _detector_by_path
|
||||
_detector_by_path.clear()
|
||||
|
||||
|
||||
def _get_detector(model_path: str):
|
||||
global _detector_by_path
|
||||
if not model_path or not os.path.isfile(model_path):
|
||||
return None
|
||||
if model_path in _detector_by_path:
|
||||
return _detector_by_path[model_path]
|
||||
try:
|
||||
from maix import nn
|
||||
except ImportError:
|
||||
return None
|
||||
_detector_by_path[model_path] = nn.YOLOv5(model=model_path, dual_buff=False)
|
||||
return _detector_by_path[model_path]
|
||||
|
||||
|
||||
def preload_yolo_detector(logger=None):
|
||||
"""
|
||||
启动阶段预加载 YOLO detector,避免第一次真实射箭承担模型加载开销。
|
||||
detect 使用 dual_buff=False,不再需要用首帧 warmup 抵消双缓冲的一帧延迟。
|
||||
"""
|
||||
try:
|
||||
import config as cfg
|
||||
except Exception as e:
|
||||
if logger:
|
||||
logger.warning(f"[YOLO-ROI] 预加载失败:无法读取 config: {e}")
|
||||
return False
|
||||
|
||||
ok = False
|
||||
|
||||
if bool(getattr(cfg, "TRIANGLE_YOLO_ROI_ENABLE", False)):
|
||||
model_path = getattr(cfg, "TRIANGLE_YOLO_MODEL_PATH", "") or ""
|
||||
det = _get_detector(model_path)
|
||||
if det is None:
|
||||
if logger:
|
||||
logger.warning(f"[YOLO-ROI] 预加载失败:无法加载模型 {model_path}")
|
||||
else:
|
||||
ok = True
|
||||
try:
|
||||
net_w = int(det.input_width())
|
||||
net_h = int(det.input_height())
|
||||
except Exception:
|
||||
net_w = net_h = -1
|
||||
if logger:
|
||||
logger.info(
|
||||
f"[YOLO-ROI] 靶环模型已预加载: {model_path}, net_in={net_w}×{net_h}"
|
||||
)
|
||||
|
||||
_loc_black = str(
|
||||
getattr(cfg, "TRIANGLE_BLACK_TRIANGLE_LOCATE_MODE", "yolo")
|
||||
).lower().strip()
|
||||
if _loc_black not in ("yolo", "traditional"):
|
||||
_loc_black = "yolo"
|
||||
_preload_black = (
|
||||
bool(getattr(cfg, "TRIANGLE_BLACK_YOLO_ENABLE", False))
|
||||
and _loc_black == "yolo"
|
||||
and bool(getattr(cfg, "TRIANGLE_BLACK_YOLO_PRELOAD_ON_BOOT", True))
|
||||
)
|
||||
if _preload_black:
|
||||
bp = getattr(cfg, "TRIANGLE_BLACK_YOLO_MODEL_PATH", "") or ""
|
||||
d2 = _get_detector(bp)
|
||||
if d2 is None:
|
||||
if logger:
|
||||
logger.warning(f"[YOLO-BLACK] 预加载失败:无法加载模型 {bp}")
|
||||
else:
|
||||
ok = True
|
||||
try:
|
||||
nw2 = int(d2.input_width())
|
||||
nh2 = int(d2.input_height())
|
||||
except Exception:
|
||||
nw2 = nh2 = -1
|
||||
if logger:
|
||||
logger.info(
|
||||
f"[YOLO-BLACK] 黑三角模型已预加载: {bp}, net_in={nw2}×{nh2}"
|
||||
)
|
||||
elif logger and bool(getattr(cfg, "TRIANGLE_BLACK_YOLO_ENABLE", False)):
|
||||
if _loc_black != "yolo":
|
||||
logger.info(
|
||||
"[YOLO-BLACK] TRIANGLE_BLACK_TRIANGLE_LOCATE_MODE=%s:跳过黑三角模型预加载"
|
||||
% (_loc_black,)
|
||||
)
|
||||
|
||||
return ok
|
||||
|
||||
|
||||
def _letterbox_net_to_src_xyxy(
|
||||
x: float, y: float, w: float, h: float,
|
||||
src_w: int, src_h: int, net_w: int, net_h: int,
|
||||
):
|
||||
"""
|
||||
检测框在网络输入图上(含 letterbox 填充),映回到 src_w×src_h 原图。
|
||||
x,y,w,h 为网络坐标系下的左上角与宽高。
|
||||
"""
|
||||
scale = min(net_w / float(src_w), net_h / float(src_h))
|
||||
nw = src_w * scale
|
||||
nh = src_h * scale
|
||||
pad_x = (net_w - nw) * 0.5
|
||||
pad_y = (net_h - nh) * 0.5
|
||||
x0 = (x - pad_x) / scale
|
||||
y0 = (y - pad_y) / scale
|
||||
x1 = (x + w - pad_x) / scale
|
||||
y1 = (y + h - pad_y) / scale
|
||||
return x0, y0, x1, y1
|
||||
|
||||
|
||||
def _det_obj_class_id(o):
|
||||
"""Maix / 不同版本可能用 class_id、cls、label 等字段。"""
|
||||
for key in ("class_id", "cls", "label", "category", "cat_id", "id"):
|
||||
if hasattr(o, key):
|
||||
v = getattr(o, key)
|
||||
if v is None:
|
||||
continue
|
||||
try:
|
||||
return int(float(v))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _det_obj_from_seq(t):
|
||||
"""若 detect 返回 list/tuple:[x,y,w,h,score,cls](Maix 常用 xywh),包装成属性对象。"""
|
||||
if not isinstance(t, (list, tuple)) or len(t) < 6:
|
||||
return None
|
||||
|
||||
class _Box:
|
||||
__slots__ = ("x", "y", "w", "h", "score", "class_id")
|
||||
|
||||
b = _Box()
|
||||
b.x = float(t[0])
|
||||
b.y = float(t[1])
|
||||
b.w = float(t[2])
|
||||
b.h = float(t[3])
|
||||
b.score = float(t[4])
|
||||
b.class_id = int(float(t[5]))
|
||||
return b
|
||||
|
||||
|
||||
def _normalize_objs(objs):
|
||||
out = []
|
||||
for o in objs or []:
|
||||
if isinstance(o, (list, tuple)):
|
||||
m = _det_obj_from_seq(o)
|
||||
if m is not None:
|
||||
out.append(m)
|
||||
else:
|
||||
out.append(o)
|
||||
return out
|
||||
|
||||
|
||||
def _det_to_src_xyxy(o, coord_mode: str, src_w: int, src_h: int, net_w: int, net_h: int):
|
||||
"""把单个检测框转为全图坐标系下的 xyxy(半开区间语义与后续 clip 一致)。"""
|
||||
x, y, w, h = float(o.x), float(o.y), float(o.w), float(o.h)
|
||||
if coord_mode in ("native", "source", "camera", "full"):
|
||||
return x, y, x + w, y + h
|
||||
return _letterbox_net_to_src_xyxy(x, y, w, h, src_w, src_h, net_w, net_h)
|
||||
|
||||
|
||||
def _merge_roi_xyxy(xy_list, merge_mode: str):
|
||||
"""
|
||||
merge_mode:
|
||||
union — 所有框的外接矩形(适合「整靶+多角标」同属一类、多框场景)
|
||||
largest — 取面积最大的单个框(适合只有一个大框代表整靶)
|
||||
"""
|
||||
if not xy_list:
|
||||
return None
|
||||
if merge_mode in ("union", "merge", "all"):
|
||||
x0 = min(a[0] for a in xy_list)
|
||||
y0 = min(a[1] for a in xy_list)
|
||||
x1 = max(a[2] for a in xy_list)
|
||||
y1 = max(a[3] for a in xy_list)
|
||||
return x0, y0, x1, y1
|
||||
# largest
|
||||
def _area(t):
|
||||
return max(0.0, t[2] - t[0]) * max(0.0, t[3] - t[1])
|
||||
|
||||
best = max(xy_list, key=_area)
|
||||
return best[0], best[1], best[2], best[3]
|
||||
|
||||
|
||||
def _roi_aspect_sane(x0, y0, x1, y1, src_w: int, src_h: int) -> bool:
|
||||
"""过滤 letterbox 重复映射等导致的扁条/细条 ROI。"""
|
||||
bw = x1 - x0
|
||||
bh = y1 - y0
|
||||
if bw < 8 or bh < 8:
|
||||
return False
|
||||
area_frac = (bw * bh) / float(max(1, src_w * src_h))
|
||||
if area_frac < 0.015: # 小于全图约 1.5% 认为不可信
|
||||
return False
|
||||
ar = bw / max(bh, 1e-6)
|
||||
if ar > 5.5 or ar < 1.0 / 5.5:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _expand_xyxy(x0, y0, x1, y1, src_w, src_h, margin_frac: float):
|
||||
bw = max(x1 - x0, 1e-6)
|
||||
bh = max(y1 - y0, 1e-6)
|
||||
mx = bw * margin_frac
|
||||
my = bh * margin_frac
|
||||
x0 -= mx
|
||||
y0 -= my
|
||||
x1 += mx
|
||||
y1 += my
|
||||
x0 = max(0, min(int(round(x0)), src_w - 1))
|
||||
y0 = max(0, min(int(round(y0)), src_h - 1))
|
||||
x1 = max(x0 + 1, min(int(round(x1)), src_w))
|
||||
y1 = max(y0 + 1, min(int(round(y1)), src_h))
|
||||
return x0, y0, x1, y1
|
||||
|
||||
|
||||
def try_get_triangle_roi_from_yolo(maix_frame, src_w: int, src_h: int, logger=None):
|
||||
"""
|
||||
用 YOLO 在 maix_frame 上检测靶环类,返回整图上的裁剪框 (x0,y0,x1,y1);失败返回 None。
|
||||
|
||||
:param maix_frame: camera.read() 返回的 Maix 图像(与 nn.YOLOv5.detect 一致)
|
||||
:param src_w, src_h: 与 img_cv / 标定一致的分辨率(通常与 camera 一致)
|
||||
"""
|
||||
try:
|
||||
import config as cfg
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not bool(getattr(cfg, "TRIANGLE_YOLO_ROI_ENABLE", False)):
|
||||
return None
|
||||
|
||||
model_path = getattr(cfg, "TRIANGLE_YOLO_MODEL_PATH", "") or ""
|
||||
if not os.path.isfile(model_path):
|
||||
if logger:
|
||||
logger.warning(f"[YOLO-ROI] 模型文件不存在: {model_path}")
|
||||
return None
|
||||
|
||||
det = _get_detector(model_path)
|
||||
if det is None:
|
||||
if logger:
|
||||
logger.warning("[YOLO-ROI] 无法加载 nn.YOLOv5(非 Maix 环境或导入失败)")
|
||||
return None
|
||||
|
||||
conf_th = float(getattr(cfg, "TRIANGLE_YOLO_CONF_TH", 0.5))
|
||||
iou_th = float(getattr(cfg, "TRIANGLE_YOLO_IOU_TH", 0.45))
|
||||
class_ids = getattr(cfg, "TRIANGLE_YOLO_RING_CLASS_IDS", (0,))
|
||||
if isinstance(class_ids, int):
|
||||
class_ids = (class_ids,)
|
||||
margin_frac = float(getattr(cfg, "TRIANGLE_YOLO_ROI_MARGIN_FRAC", 0.12))
|
||||
coord_mode = str(getattr(cfg, "TRIANGLE_YOLO_COORD_MODE", "native")).lower()
|
||||
merge_mode = str(getattr(cfg, "TRIANGLE_YOLO_ROI_MERGE_MODE", "union")).lower()
|
||||
reject_bad = bool(getattr(cfg, "TRIANGLE_YOLO_REJECT_BAD_ROI", True))
|
||||
|
||||
try:
|
||||
raw = det.detect(maix_frame, conf_th=conf_th, iou_th=iou_th)
|
||||
except Exception as e:
|
||||
if logger:
|
||||
logger.warning(f"[YOLO-ROI] detect 异常: {e}")
|
||||
return None
|
||||
|
||||
objs = _normalize_objs(raw if raw is not None else [])
|
||||
|
||||
candidates = []
|
||||
for o in objs:
|
||||
cid = _det_obj_class_id(o)
|
||||
if cid is not None and cid in class_ids:
|
||||
candidates.append(o)
|
||||
|
||||
if not candidates and bool(getattr(cfg, "TRIANGLE_YOLO_RETRY_ON_EMPTY", False)):
|
||||
retry_conf = float(getattr(cfg, "TRIANGLE_YOLO_RETRY_CONF_TH", conf_th))
|
||||
if retry_conf > 0 and retry_conf < conf_th:
|
||||
try:
|
||||
raw_retry = det.detect(maix_frame, conf_th=retry_conf, iou_th=iou_th)
|
||||
objs_retry = _normalize_objs(raw_retry if raw_retry is not None else [])
|
||||
candidates_retry = []
|
||||
for o in objs_retry:
|
||||
cid = _det_obj_class_id(o)
|
||||
if cid is not None and cid in class_ids:
|
||||
candidates_retry.append(o)
|
||||
if candidates_retry:
|
||||
if logger:
|
||||
logger.info(
|
||||
f"[YOLO-ROI] conf={conf_th} 下 0 候选,"
|
||||
f"用 retry_conf={retry_conf} 重试得到 {len(candidates_retry)} 个候选"
|
||||
)
|
||||
objs = objs_retry
|
||||
candidates = candidates_retry
|
||||
conf_th = retry_conf
|
||||
elif logger:
|
||||
logger.info(
|
||||
f"[YOLO-ROI] conf={conf_th} 下 0 候选;"
|
||||
f"retry_conf={retry_conf} 仍为 0 候选"
|
||||
)
|
||||
except Exception as e:
|
||||
if logger:
|
||||
logger.warning(f"[YOLO-ROI] 低阈值重试异常: {e}")
|
||||
|
||||
if not candidates:
|
||||
if logger:
|
||||
n = len(objs)
|
||||
if n == 0:
|
||||
logger.info(
|
||||
f"[YOLO-ROI] detect 返回 0 个框(conf≥{conf_th})。"
|
||||
f"可尝试 config 里降低 TRIANGLE_YOLO_CONF_TH(如 0.25~0.35),"
|
||||
f"或确认射箭帧与训练图光照/构图接近。"
|
||||
)
|
||||
else:
|
||||
seen = []
|
||||
for o in objs[:8]:
|
||||
cid = _det_obj_class_id(o)
|
||||
sc = getattr(o, "score", None)
|
||||
try:
|
||||
sc_f = float(sc) if sc is not None else None
|
||||
except Exception:
|
||||
sc_f = None
|
||||
seen.append(f"cls={cid},score={sc_f}")
|
||||
logger.info(
|
||||
f"[YOLO-ROI] 有 {n} 个框但类别不在 {class_ids} 内;"
|
||||
f"前几条: {seen}。请核对 TRIANGLE_YOLO_RING_CLASS_IDS,"
|
||||
f"或查看 Maix 文档中检测结果的类别字段名。"
|
||||
)
|
||||
return None
|
||||
|
||||
net_w = int(det.input_width())
|
||||
net_h = int(det.input_height())
|
||||
|
||||
min_side = float(getattr(cfg, "TRIANGLE_YOLO_MIN_BOX_SIDE_PX", 8.0))
|
||||
xy_list = []
|
||||
for o in candidates:
|
||||
x0n, y0n, x1n, y1n = _det_to_src_xyxy(o, coord_mode, src_w, src_h, net_w, net_h)
|
||||
bw, bh = x1n - x0n, y1n - y0n
|
||||
if bw >= min_side and bh >= min_side:
|
||||
xy_list.append((x0n, y0n, x1n, y1n))
|
||||
|
||||
if not xy_list:
|
||||
if logger:
|
||||
logger.info(
|
||||
f"[YOLO-ROI] {len(candidates)} 个候选经 min_side={min_side} 过滤后为空,放弃 ROI"
|
||||
)
|
||||
return None
|
||||
|
||||
merged = _merge_roi_xyxy(xy_list, merge_mode)
|
||||
if merged is None:
|
||||
return None
|
||||
x0, y0, x1, y1 = merged
|
||||
|
||||
# clip 到画布(合并前框可能略越界)
|
||||
x0 = max(0, min(x0, src_w - 1))
|
||||
y0 = max(0, min(y0, src_h - 1))
|
||||
x1 = max(x0 + 1, min(x1, src_w))
|
||||
y1 = max(y0 + 1, min(y1, src_h))
|
||||
|
||||
x0, y0, x1, y1 = _expand_xyxy(x0, y0, x1, y1, src_w, src_h, margin_frac)
|
||||
|
||||
if reject_bad and not _roi_aspect_sane(x0, y0, x1, y1, src_w, src_h):
|
||||
if logger:
|
||||
logger.warning(
|
||||
f"[YOLO-ROI] 裁剪框异常(过小或过扁)mode={coord_mode} merge={merge_mode} "
|
||||
f"→ [{x0},{y0},{x1},{y1}],放弃 ROI、三角形改用整图。"
|
||||
f"若持续出现可尝试 coord_mode=letterbox/native 切换。"
|
||||
)
|
||||
return None
|
||||
|
||||
if logger:
|
||||
nbox = len(candidates)
|
||||
logger.info(
|
||||
f"[YOLO-ROI] boxes={nbox} merge={merge_mode} coord={coord_mode} "
|
||||
f"net_in={net_w}×{net_h}(来自模型) → crop=[{x0},{y0},{x1},{y1}] "
|
||||
f"({x1-x0}×{y1-y0}px)"
|
||||
)
|
||||
|
||||
return (x0, y0, x1, y1)
|
||||
|
||||
|
||||
def _expand_xyxy_local(x0, y0, x1, y1, w_lim, h_lim, margin_frac: float):
|
||||
"""在宽 w_lim、高 h_lim 的局部坐标系内扩展框。"""
|
||||
bw = max(x1 - x0, 1e-6)
|
||||
bh = max(y1 - y0, 1e-6)
|
||||
mx = bw * margin_frac
|
||||
my = bh * margin_frac
|
||||
x0 -= mx
|
||||
y0 -= my
|
||||
x1 += mx
|
||||
y1 += my
|
||||
x0 = max(0, min(int(round(x0)), w_lim - 1))
|
||||
y0 = max(0, min(int(round(y0)), h_lim - 1))
|
||||
x1 = max(x0 + 1, min(int(round(x1)), w_lim))
|
||||
y1 = max(y0 + 1, min(int(round(y1)), h_lim))
|
||||
return x0, y0, x1, y1
|
||||
|
||||
|
||||
def try_black_triangle_boxes_work(img_rgb, ring_roi_xyxy, logger=None):
|
||||
"""
|
||||
Stage2:在 **Stage1 靶环 ROI 裁切图** 上跑黑三角 YOLO(与训练时 stage2 构图一致),
|
||||
检测框坐标已落在 **靶环裁切图**(与 try_triangle_scoring 中 img_work)同一坐标系,
|
||||
返回 (x0,y0,x1,y1) 整数元组列表。
|
||||
|
||||
img_rgb: 与 try_triangle_scoring 相同的全图 RGB(numpy,H×W×3)。
|
||||
ring_roi_xyxy: 全图上的 (rx0, ry0, rx1, ry1),与 try_get_triangle_roi_from_yolo 一致。
|
||||
"""
|
||||
if ring_roi_xyxy is None:
|
||||
return []
|
||||
if img_rgb is None or getattr(img_rgb, "size", 0) == 0:
|
||||
return []
|
||||
try:
|
||||
import config as cfg
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
if not bool(getattr(cfg, "TRIANGLE_BLACK_YOLO_ENABLE", False)):
|
||||
return []
|
||||
|
||||
model_path = getattr(cfg, "TRIANGLE_BLACK_YOLO_MODEL_PATH", "") or ""
|
||||
if not os.path.isfile(model_path):
|
||||
if logger:
|
||||
logger.warning(f"[YOLO-BLACK] 模型文件不存在: {model_path}")
|
||||
return []
|
||||
|
||||
det = _get_detector(model_path)
|
||||
if det is None:
|
||||
if logger:
|
||||
logger.warning("[YOLO-BLACK] 无法加载 nn.YOLOv5")
|
||||
return []
|
||||
|
||||
conf_th = float(getattr(cfg, "TRIANGLE_BLACK_YOLO_CONF_TH", 0.5))
|
||||
iou_th = float(getattr(cfg, "TRIANGLE_BLACK_YOLO_IOU_TH", 0.45))
|
||||
class_ids = getattr(cfg, "TRIANGLE_BLACK_YOLO_CLASS_IDS", (0,))
|
||||
if isinstance(class_ids, int):
|
||||
class_ids = (class_ids,)
|
||||
coord_mode = str(getattr(cfg, "TRIANGLE_BLACK_YOLO_COORD_MODE", "native")).lower()
|
||||
margin_frac = float(getattr(cfg, "TRIANGLE_BLACK_YOLO_BOX_MARGIN_FRAC", 0.08))
|
||||
min_side = float(getattr(cfg, "TRIANGLE_BLACK_YOLO_MIN_BOX_SIDE_PX", 6.0))
|
||||
crop_min = int(getattr(cfg, "TRIANGLE_CROP_ROI_MIN_SIDE_PX", 64))
|
||||
|
||||
h_full, w_full = int(img_rgb.shape[0]), int(img_rgb.shape[1])
|
||||
rx0, ry0, rx1, ry1 = [int(round(float(v))) for v in ring_roi_xyxy]
|
||||
rx0 = max(0, min(rx0, w_full - 1))
|
||||
ry0 = max(0, min(ry0, h_full - 1))
|
||||
rx1 = max(rx0 + 1, min(rx1, w_full))
|
||||
ry1 = max(ry0 + 1, min(ry1, h_full))
|
||||
rw, rh = rx1 - rx0, ry1 - ry0
|
||||
|
||||
if rw < crop_min or rh < crop_min:
|
||||
if logger:
|
||||
logger.warning(
|
||||
f"[YOLO-BLACK] Stage1 ROI 过小 {rw}×{rh} < {crop_min},跳过黑三角检测"
|
||||
)
|
||||
return []
|
||||
|
||||
# 必须与相机帧缓冲区脱钩:切片常为非连续视图,直接喂 cv2image/NPU 易 SIGSEGV
|
||||
slab = np.ascontiguousarray(
|
||||
img_rgb[ry0:ry1, rx0:rx1], dtype=np.uint8
|
||||
).copy()
|
||||
if slab.size == 0:
|
||||
return []
|
||||
|
||||
_save_roi = bool(getattr(cfg, "TRIANGLE_BLACK_YOLO_SAVE_ROI_CROP", False))
|
||||
|
||||
try:
|
||||
from maix import image as maix_image
|
||||
|
||||
# copy=True:零拷贝时 detect 内 OpenCV 可能对底层 Mat release 触发 !fixedSize() 断言。
|
||||
roi_maix = maix_image.cv2image(slab, False, True)
|
||||
except Exception as e:
|
||||
if logger:
|
||||
logger.warning(f"[YOLO-BLACK] 裁切图转 Maix image 失败: {e}")
|
||||
return []
|
||||
|
||||
try:
|
||||
raw = det.detect(roi_maix, conf_th=conf_th, iou_th=iou_th)
|
||||
except Exception as e:
|
||||
if logger:
|
||||
logger.warning(f"[YOLO-BLACK] detect 异常: {e}")
|
||||
return []
|
||||
|
||||
objs = _normalize_objs(raw if raw is not None else [])
|
||||
net_w = int(det.input_width())
|
||||
net_h = int(det.input_height())
|
||||
|
||||
n_raw = len(objs)
|
||||
n_cls_ok = 0
|
||||
n_too_small = 0
|
||||
|
||||
out_local = []
|
||||
for o in objs:
|
||||
cid = _det_obj_class_id(o)
|
||||
if cid is None or cid not in class_ids:
|
||||
continue
|
||||
n_cls_ok += 1
|
||||
x0f, y0f, x1f, y1f = _det_to_src_xyxy(o, coord_mode, rw, rh, net_w, net_h)
|
||||
lx0 = max(0, min(float(x0f), rw - 1))
|
||||
ly0 = max(0, min(float(y0f), rh - 1))
|
||||
lx1 = max(lx0 + 1, min(float(x1f), rw))
|
||||
ly1 = max(ly0 + 1, min(float(y1f), rh))
|
||||
lx0, ly0, lx1, ly1 = int(round(lx0)), int(round(ly0)), int(round(lx1)), int(round(ly1))
|
||||
if (lx1 - lx0) < min_side or (ly1 - ly0) < min_side:
|
||||
n_too_small += 1
|
||||
continue
|
||||
lx0, ly0, lx1, ly1 = _expand_xyxy_local(
|
||||
lx0, ly0, lx1, ly1, rw, rh, margin_frac
|
||||
)
|
||||
out_local.append((lx0, ly0, lx1, ly1))
|
||||
|
||||
out_local.sort(key=lambda t: ((t[1] + t[3]) * 0.5, (t[0] + t[2]) * 0.5))
|
||||
|
||||
if logger and bool(
|
||||
getattr(cfg, "TRIANGLE_BLACK_YOLO_LOG_EACH_SHOT", True)
|
||||
):
|
||||
msg = (
|
||||
f"[YOLO-BLACK] Stage1裁切{rw}×{rh}上推理: raw={n_raw} 类∈{class_ids}→{n_cls_ok} "
|
||||
f"过小丢弃→{n_too_small} 最终子框={len(out_local)} "
|
||||
f"(conf={conf_th}, coord={coord_mode}, net={net_w}×{net_h}, "
|
||||
f"ring全图=[{rx0},{ry0},{rx1},{ry1}])"
|
||||
)
|
||||
logger.info(msg)
|
||||
if n_raw > 0 and n_cls_ok == 0:
|
||||
seen = []
|
||||
for o in objs[:8]:
|
||||
cid = _det_obj_class_id(o)
|
||||
sc = getattr(o, "score", None)
|
||||
try:
|
||||
sc_f = float(sc) if sc is not None else None
|
||||
except Exception:
|
||||
sc_f = None
|
||||
seen.append(f"cls={cid},score={sc_f}")
|
||||
logger.info(
|
||||
f"[YOLO-BLACK] 有框但类别不在 {class_ids} 内;前几条: {seen}。"
|
||||
f"请核对 TRIANGLE_BLACK_YOLO_CLASS_IDS。"
|
||||
)
|
||||
elif n_cls_ok > 0 and len(out_local) == 0:
|
||||
logger.info(
|
||||
f"[YOLO-BLACK] {n_cls_ok} 个目标类框但边长均 < min_side={min_side},已全部丢弃。"
|
||||
)
|
||||
|
||||
if _save_roi:
|
||||
try:
|
||||
base = (getattr(cfg, "TRIANGLE_BLACK_YOLO_ROI_CROP_DIR", "") or "").strip()
|
||||
if not base:
|
||||
base = os.path.join(
|
||||
getattr(cfg, "PHOTO_DIR", "/tmp") or "/tmp", "stage2_roi"
|
||||
)
|
||||
_draw = bool(
|
||||
getattr(cfg, "TRIANGLE_BLACK_YOLO_SAVE_ROI_DRAW_BOXES", True)
|
||||
)
|
||||
_roi_max_raw = getattr(
|
||||
cfg, "TRIANGLE_BLACK_YOLO_STAGE2_ROI_MAX_IMAGES", None
|
||||
)
|
||||
try:
|
||||
_roi_max = (
|
||||
int(_roi_max_raw)
|
||||
if _roi_max_raw is not None
|
||||
else int(getattr(cfg, "MAX_IMAGES", 1000))
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
_roi_max = int(getattr(cfg, "MAX_IMAGES", 1000))
|
||||
slab_copy = np.ascontiguousarray(slab, dtype=np.uint8).copy()
|
||||
boxes_copy = [tuple(t) for t in out_local]
|
||||
threading.Thread(
|
||||
target=_stage2_roi_crop_save_worker,
|
||||
args=(
|
||||
slab_copy,
|
||||
boxes_copy,
|
||||
rx0,
|
||||
ry0,
|
||||
rw,
|
||||
rh,
|
||||
base,
|
||||
_draw,
|
||||
92,
|
||||
_roi_max,
|
||||
logger,
|
||||
),
|
||||
daemon=True,
|
||||
).start()
|
||||
except Exception as e:
|
||||
if logger:
|
||||
logger.warning(f"[YOLO-BLACK] 提交异步保存裁切图失败: {e}")
|
||||
|
||||
return out_local
|
||||
50
test/test_audio.py
Normal file
50
test/test_audio.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# test_audio.pyx
|
||||
from maix import audio, time, app, gpio
|
||||
|
||||
def run_player_loop():
|
||||
"""
|
||||
播放控制主循环函数
|
||||
"""
|
||||
# 初始化音频播放器
|
||||
p = audio.Player("/root/gun.wav")
|
||||
p.volume(40)
|
||||
|
||||
# 初始化 GPIO 引脚为输出
|
||||
led = gpio.GPIO("A25", gpio.Mode.OUT)
|
||||
# 设置低电平
|
||||
led.value(0)
|
||||
|
||||
# 主循环
|
||||
while not app.need_exit():
|
||||
led.value(1) # 点亮 LED
|
||||
time.sleep_ms(200) # 保持 200ms
|
||||
led.value(0) # 熄灭 LED
|
||||
p.play() # 播放音频
|
||||
time.sleep_ms(1000) # 等待 1 秒
|
||||
|
||||
print("play finish!")
|
||||
|
||||
|
||||
# 可选:添加一个简单的测试函数
|
||||
def hello():
|
||||
return "Hello from test_audio!"
|
||||
|
||||
|
||||
# 可选:添加一个初始化函数
|
||||
def init_led():
|
||||
"""单独测试 GPIO"""
|
||||
led = gpio.GPIO("A25", gpio.Mode.OUT)
|
||||
led.value(0)
|
||||
return "LED initialized"
|
||||
|
||||
|
||||
# 可选:添加一个播放测试函数
|
||||
def test_play():
|
||||
"""单独测试音频播放"""
|
||||
p = audio.Player("/root/gun.wav")
|
||||
p.volume(50)
|
||||
p.play()
|
||||
return "Playing..."
|
||||
|
||||
|
||||
run_player_loop()
|
||||
25
test/test_button.py
Normal file
25
test/test_button.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from maix import audio, time, app,gpio
|
||||
|
||||
|
||||
# button1 = gpio.GPIO("ADC", gpio.Mode.IN)
|
||||
button3 = gpio.GPIO("A26", gpio.Mode.IN) # 可用
|
||||
button2 = gpio.GPIO("A16", gpio.Mode.IN)
|
||||
#设置低电平
|
||||
from maix.peripheral import adc
|
||||
channel = 0
|
||||
res_bit = adc.RES_BIT_12
|
||||
_adc_obj = adc.ADC(channel, res_bit)
|
||||
|
||||
|
||||
while not app.need_exit():
|
||||
# print(f"b1: {button1.value()}")
|
||||
|
||||
print(f"b2: {button2.value()}")
|
||||
|
||||
# print(_adc_obj.read_vol())
|
||||
print(f"b3: {button3.value()}")
|
||||
time.sleep_ms(50)
|
||||
|
||||
# time.sleep_ms(1000)
|
||||
|
||||
|
||||
36
test/test_camera_rtsp.py
Normal file
36
test/test_camera_rtsp.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# from maix import time, rtsp, camera, image
|
||||
|
||||
# # 1. 初始化摄像头(注意:RTSP需要NV21格式)
|
||||
# # 分辨率可以根据需要调整,如 640x480 或 1280x720
|
||||
# cam = camera.Camera(640, 480, image.Format.FMT_YVU420SP)
|
||||
|
||||
# # 2. 创建并启动RTSP服务器
|
||||
# server = rtsp.Rtsp()
|
||||
# server.bind_camera(cam)
|
||||
# server.start()
|
||||
|
||||
# # 3. 打印出访问地址,例如: rtsp://192.168.xxx.xxx:8554/live
|
||||
# print("RTSP 流地址:", server.get_url())
|
||||
|
||||
# # 4. 保持服务运行
|
||||
# while True:
|
||||
# time.sleep(1)
|
||||
|
||||
|
||||
|
||||
from maix import camera, time, app, http, image
|
||||
|
||||
# 初始化相机,注意格式要用 FMT_RGB888(JPEG 编码需要 RGB 输入)
|
||||
cam = camera.Camera(640, 480, image.Format.FMT_RGB888)
|
||||
|
||||
# 创建 JPEG 流服务器
|
||||
stream = http.JpegStreamer()
|
||||
stream.start()
|
||||
|
||||
print("RTSP 替代方案 - HTTP JPEG 流地址: http://{}:{}".format(stream.host(), stream.port()))
|
||||
print("请在浏览器或 OpenCV 中访问: http://<MaixCAM_IP>:8000/stream")
|
||||
|
||||
while not app.need_exit():
|
||||
img = cam.read()
|
||||
jpg = img.to_jpeg() # 将 RGB 图像编码为 JPEG
|
||||
stream.write(jpg) # 推送到 HTTP 客户端
|
||||
20
test/test_cammera.py
Normal file
20
test/test_cammera.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# test_camera.py
|
||||
from maix import camera, display, time
|
||||
|
||||
try:
|
||||
print("Initializing camera...")
|
||||
cam = camera.Camera(640,480)
|
||||
# cam = camera.Camera(1280,720)
|
||||
# cam.get_exposure_us()
|
||||
# print("Camera exposure: ", cam.get_exposure_us())
|
||||
print("Camera initialized successfully!")
|
||||
|
||||
disp = display.Display()
|
||||
|
||||
while True:
|
||||
frame = cam.read()
|
||||
disp.show(frame)
|
||||
time.sleep_ms(50)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
620
test/test_decect_circle.py
Normal file
620
test/test_decect_circle.py
Normal file
@@ -0,0 +1,620 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
离线测试脚本:直接复用 detect_circle 逻辑进行测试
|
||||
运行环境:MaixPy (Sipeed MAIX)
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
# import time
|
||||
from maix import image,time
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
# ==================== 全局配置 (与 test_main.py 保持一致) ====================
|
||||
REAL_RADIUS_CM = 20 # 靶心实际半径(厘米)
|
||||
|
||||
# ==================== 复制的核心算法 ====================
|
||||
# 注意:这里直接复制了 detect_circle 的逻辑,避免 import main 导致的冲突
|
||||
|
||||
|
||||
def detect_circle_v3(frame, laser_point=None):
|
||||
"""检测图像中的靶心(优先清晰轮廓,其次黄色区域)- 返回椭圆参数版本
|
||||
增加红色圆圈检测,验证黄色圆圈是否为真正的靶心
|
||||
如果提供 laser_point,会选择最接近激光点的目标
|
||||
|
||||
Args:
|
||||
frame: 图像帧
|
||||
laser_point: 激光点坐标 (x, y),用于多目标场景下的目标选择
|
||||
|
||||
Returns:
|
||||
(result_img, best_center, best_radius, method, best_radius1, ellipse_params)
|
||||
"""
|
||||
img_cv = image.image2cv(frame, False, False)
|
||||
|
||||
best_center = best_radius = best_radius1 = method = None
|
||||
ellipse_params = None
|
||||
|
||||
# HSV 黄色掩码检测(模糊靶心)
|
||||
hsv = cv2.cvtColor(img_cv, cv2.COLOR_RGB2HSV)
|
||||
h, s, v = cv2.split(hsv)
|
||||
|
||||
# 调整饱和度策略:稍微增强,不要过度
|
||||
s = np.clip(s * 1.1, 0, 255).astype(np.uint8)
|
||||
|
||||
hsv = cv2.merge((h, s, v))
|
||||
|
||||
# 放宽 HSV 阈值范围(针对模糊图像的关键调整)
|
||||
lower_yellow = np.array([7, 80, 0]) # 饱和度下限降低,捕捉淡黄色
|
||||
upper_yellow = np.array([32, 255, 255]) # 亮度上限拉满
|
||||
|
||||
mask_yellow = cv2.inRange(hsv, lower_yellow, upper_yellow)
|
||||
|
||||
# 调整形态学操作
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
||||
mask_yellow = cv2.morphologyEx(mask_yellow, cv2.MORPH_CLOSE, kernel)
|
||||
|
||||
contours_yellow, _ = cv2.findContours(mask_yellow, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
# 存储所有有效的黄色-红色组合
|
||||
valid_targets = []
|
||||
|
||||
if contours_yellow:
|
||||
for cnt_yellow in contours_yellow:
|
||||
area = cv2.contourArea(cnt_yellow)
|
||||
perimeter = cv2.arcLength(cnt_yellow, True)
|
||||
|
||||
# 计算圆度
|
||||
if perimeter > 0:
|
||||
circularity = (4 * np.pi * area) / (perimeter * perimeter)
|
||||
else:
|
||||
circularity = 0
|
||||
|
||||
logger = get_logger()
|
||||
if area > 50 and circularity > 0.7:
|
||||
if logger:
|
||||
logger.info(f"[target] -> 面积:{area}, 圆度:{circularity:.2f}")
|
||||
# 尝试拟合椭圆
|
||||
yellow_center = None
|
||||
yellow_radius = None
|
||||
yellow_ellipse = None
|
||||
|
||||
if len(cnt_yellow) >= 5:
|
||||
(x, y), (width, height), angle = cv2.fitEllipse(cnt_yellow)
|
||||
yellow_ellipse = ((x, y), (width, height), angle)
|
||||
axes_minor = min(width, height)
|
||||
radius = axes_minor / 2
|
||||
yellow_center = (int(x), int(y))
|
||||
yellow_radius = int(radius)
|
||||
else:
|
||||
(x, y), radius = cv2.minEnclosingCircle(cnt_yellow)
|
||||
yellow_center = (int(x), int(y))
|
||||
yellow_radius = int(radius)
|
||||
yellow_ellipse = None
|
||||
|
||||
# 如果检测到黄色圆圈,再检测红色圆圈进行验证
|
||||
if yellow_center and yellow_radius:
|
||||
# HSV 红色掩码检测(红色在HSV中跨越0度,需要两个范围)
|
||||
# 红色范围1: 0-10度(接近0度的红色)
|
||||
lower_red1 = np.array([0, 80, 0])
|
||||
upper_red1 = np.array([10, 255, 255])
|
||||
mask_red1 = cv2.inRange(hsv, lower_red1, upper_red1)
|
||||
|
||||
# 红色范围2: 170-180度(接近180度的红色)
|
||||
lower_red2 = np.array([170, 80, 0])
|
||||
upper_red2 = np.array([180, 255, 255])
|
||||
mask_red2 = cv2.inRange(hsv, lower_red2, upper_red2)
|
||||
|
||||
# 合并两个红色掩码
|
||||
mask_red = cv2.bitwise_or(mask_red1, mask_red2)
|
||||
|
||||
# 形态学操作
|
||||
kernel_red = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
||||
mask_red = cv2.morphologyEx(mask_red, cv2.MORPH_CLOSE, kernel_red)
|
||||
|
||||
contours_red, _ = cv2.findContours(mask_red, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
found_valid_red = False
|
||||
|
||||
if contours_red:
|
||||
# 找到所有符合条件的红色圆圈
|
||||
for cnt_red in contours_red:
|
||||
area_red = cv2.contourArea(cnt_red)
|
||||
perimeter_red = cv2.arcLength(cnt_red, True)
|
||||
|
||||
if perimeter_red > 0:
|
||||
circularity_red = (4 * np.pi * area_red) / (perimeter_red * perimeter_red)
|
||||
else:
|
||||
circularity_red = 0
|
||||
|
||||
# 红色圆圈也应该有一定的圆度
|
||||
if area_red > 50 and circularity_red > 0.6:
|
||||
# 计算红色圆圈的中心和半径
|
||||
if len(cnt_red) >= 5:
|
||||
(x_red, y_red), (w_red, h_red), angle_red = cv2.fitEllipse(cnt_red)
|
||||
radius_red = min(w_red, h_red) / 2
|
||||
red_center = (int(x_red), int(y_red))
|
||||
red_radius = int(radius_red)
|
||||
else:
|
||||
(x_red, y_red), radius_red = cv2.minEnclosingCircle(cnt_red)
|
||||
red_center = (int(x_red), int(y_red))
|
||||
red_radius = int(radius_red)
|
||||
|
||||
# 计算黄色和红色圆心的距离
|
||||
if red_center:
|
||||
dx = yellow_center[0] - red_center[0]
|
||||
dy = yellow_center[1] - red_center[1]
|
||||
distance = np.sqrt(dx*dx + dy*dy)
|
||||
|
||||
# 圆心距离阈值:应该小于黄色半径的某个倍数(比如1.5倍)
|
||||
max_distance = yellow_radius * 1.5
|
||||
|
||||
# 红色圆圈应该比黄色圆圈大(外圈)
|
||||
if distance < max_distance and red_radius > yellow_radius * 0.8:
|
||||
found_valid_red = True
|
||||
logger = get_logger()
|
||||
if logger:
|
||||
logger.info(f"[target] -> 找到匹配的红圈: 黄心({yellow_center}), 红心({red_center}), 距离:{distance:.1f}, 黄半径:{yellow_radius}, 红半径:{red_radius}")
|
||||
|
||||
# 记录这个有效目标
|
||||
valid_targets.append({
|
||||
'center': yellow_center,
|
||||
'radius': yellow_radius,
|
||||
'ellipse': yellow_ellipse,
|
||||
'area': area
|
||||
})
|
||||
break
|
||||
|
||||
if not found_valid_red:
|
||||
logger = get_logger()
|
||||
if logger:
|
||||
logger.debug("Debug -> 未找到匹配的红色圆圈,可能是误识别")
|
||||
|
||||
# 从所有有效目标中选择最佳目标
|
||||
if valid_targets:
|
||||
if laser_point:
|
||||
# 如果有激光点,选择最接近激光点的目标
|
||||
best_target = None
|
||||
min_distance = float('inf')
|
||||
for target in valid_targets:
|
||||
dx = target['center'][0] - laser_point[0]
|
||||
dy = target['center'][1] - laser_point[1]
|
||||
distance = np.sqrt(dx*dx + dy*dy)
|
||||
if distance < min_distance:
|
||||
min_distance = distance
|
||||
best_target = target
|
||||
if best_target:
|
||||
best_center = best_target['center']
|
||||
best_radius = best_target['radius']
|
||||
ellipse_params = best_target['ellipse']
|
||||
method = "v3_ellipse_red_validated_laser_selected"
|
||||
best_radius1 = best_radius * 5
|
||||
else:
|
||||
# 如果没有激光点,选择面积最大的目标
|
||||
best_target = max(valid_targets, key=lambda t: t['area'])
|
||||
best_center = best_target['center']
|
||||
best_radius = best_target['radius']
|
||||
ellipse_params = best_target['ellipse']
|
||||
method = "v3_ellipse_red_validated"
|
||||
best_radius1 = best_radius * 5
|
||||
|
||||
result_img = image.cv2image(img_cv, False, False)
|
||||
return result_img, best_center, best_radius, method, best_radius1, ellipse_params
|
||||
|
||||
def detect_circle(frame):
|
||||
"""检测图像中的靶心(优先清晰轮廓,其次黄色区域)"""
|
||||
img_cv = image.image2cv(frame, False, False)
|
||||
# gray = cv2.cvtColor(img_cv, cv2.COLOR_RGB2GRAY)
|
||||
# blurred = cv2.GaussianBlur(gray, (5, 5), 0)
|
||||
# edged = cv2.Canny(blurred, 50, 150)
|
||||
# kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
||||
# ceroded = cv2.erode(cv2.dilate(edged, kernel), kernel)
|
||||
|
||||
# contours, _ = cv2.findContours(ceroded, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||
# best_center = best_radius = best_radius1 = method = None
|
||||
|
||||
# hsv = cv2.cvtColor(img_cv, cv2.COLOR_RGB2HSV)
|
||||
# h, s, v = cv2.split(hsv)
|
||||
# s = np.clip(s * 2, 0, 255).astype(np.uint8)
|
||||
# hsv = cv2.merge((h, s, v))
|
||||
# lower_yellow = np.array([7, 80, 0])
|
||||
# upper_yellow = np.array([32, 255, 182])
|
||||
# mask = cv2.inRange(hsv, lower_yellow, upper_yellow)
|
||||
# kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
||||
# mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
|
||||
# mask = cv2.morphologyEx(mask, cv2.MORPH_DILATE, kernel)
|
||||
# contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
# if contours:
|
||||
# largest = max(contours, key=cv2.contourArea)
|
||||
# if cv2.contourArea(largest) > 50:
|
||||
# (x, y), radius = cv2.minEnclosingCircle(largest)
|
||||
# best_center = (int(x), int(y))
|
||||
# best_radius = int(radius)
|
||||
# best_radius1 = radius * 5
|
||||
# method = "v2"
|
||||
|
||||
# auto
|
||||
# R:31 M:v2 D:2.410110127692767
|
||||
# hsv = cv2.cvtColor(img_cv, cv2.COLOR_RGB2HSV)
|
||||
# h, s, v = cv2.split(hsv)
|
||||
|
||||
# # 1. 增强饱和度(模糊照片需要更强的增强)
|
||||
# s = np.clip(s * 2.5, 0, 255).astype(np.uint8) # 从2.0改为2.5
|
||||
|
||||
# # 2. 增强亮度(模糊照片可能偏暗)
|
||||
# v = np.clip(v * 1.2, 0, 255).astype(np.uint8) # 新增:提升亮度
|
||||
|
||||
# hsv = cv2.merge((h, s, v))
|
||||
|
||||
# # 3. 放宽HSV颜色范围(特别是模糊照片)
|
||||
# # 降低饱和度下限,提高亮度上限
|
||||
# lower_yellow = np.array([5, 50, 30]) # H:5-35, S:50-255, V:30-255
|
||||
# upper_yellow = np.array([35, 255, 255])
|
||||
|
||||
# mask = cv2.inRange(hsv, lower_yellow, upper_yellow)
|
||||
|
||||
# # 4. 增强形态学操作(连接被分割的区域)
|
||||
# kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
||||
# kernel_large = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9)) # 更大的核
|
||||
|
||||
# # 先开运算去除噪声
|
||||
# mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel_small)
|
||||
# # 多次膨胀连接区域(模糊照片需要更多膨胀)
|
||||
# mask = cv2.dilate(mask, kernel_large, iterations=2) # 增加迭代次数
|
||||
# mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel_large) # 闭运算填充空洞
|
||||
|
||||
# contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
# if contours:
|
||||
# largest = max(contours, key=cv2.contourArea)
|
||||
# area = cv2.contourArea(largest)
|
||||
# if area > 50:
|
||||
# # 5. 使用面积计算等效半径(更准确)
|
||||
# equivalent_radius = np.sqrt(area / np.pi)
|
||||
|
||||
# # 6. 同时使用minEnclosingCircle作为备选(取较大值)
|
||||
# (x, y), enclosing_radius = cv2.minEnclosingCircle(largest)
|
||||
|
||||
# # 取两者中的较大值,确保不遗漏
|
||||
# radius = max(equivalent_radius, enclosing_radius)
|
||||
|
||||
# best_center = (int(x), int(y))
|
||||
# best_radius = int(radius)
|
||||
# best_radius1 = radius * 5
|
||||
# method = "v2"
|
||||
|
||||
# codegee
|
||||
# R:24 M:v2 D:3.061493895819174
|
||||
# R:22 M:v2 D:3.3644971681267077 np.clip(s * 1.1, 0, 255)
|
||||
hsv = cv2.cvtColor(img_cv, cv2.COLOR_RGB2HSV)
|
||||
h, s, v = cv2.split(hsv)
|
||||
|
||||
# 2. 调整饱和度策略:
|
||||
# 不要暴力翻倍,可以尝试稍微增强,或者使用 CLAHE 增强亮度/对比度
|
||||
# 这里我们稍微增加一点饱和度,并确保不溢出
|
||||
s = np.clip(s * 1.1, 0, 255).astype(np.uint8)
|
||||
# 对亮度通道 v 也可以做一点 CLAHE 处理来增强对比度(可选)
|
||||
# clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
|
||||
# v = clahe.apply(v)
|
||||
|
||||
hsv = cv2.merge((h, s, v))
|
||||
|
||||
# 3. 放宽 HSV 阈值范围(针对模糊图像的关键调整)
|
||||
# 降低 S 的下限 (80 -> 35),提高 V 的上限 (182 -> 255)
|
||||
lower_yellow = np.array([7, 80, 0]) # 饱和度下限降低,捕捉淡黄色
|
||||
upper_yellow = np.array([32, 255, 255]) # 亮度上限拉满
|
||||
|
||||
mask = cv2.inRange(hsv, lower_yellow, upper_yellow)
|
||||
|
||||
# 4. 调整形态学操作
|
||||
# 去掉 MORPH_OPEN,因为它会减小面积。
|
||||
# 使用 MORPH_CLOSE (先膨胀后腐蚀) 来填充内部小黑洞,连接近邻区域
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
||||
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
|
||||
# 再进行一次膨胀,确保边缘被包含进来
|
||||
# mask = cv2.dilate(mask, kernel, iterations=1)
|
||||
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
if contours:
|
||||
largest = max(contours, key=cv2.contourArea)
|
||||
|
||||
# 这里可以适当降低面积阈值,或者保持不变
|
||||
if cv2.contourArea(largest) > 50:
|
||||
# (x, y), radius = cv2.minEnclosingCircle(largest)
|
||||
# best_center = (int(x), int(y))
|
||||
# best_radius = int(radius)
|
||||
|
||||
# --- 核心修改开始 ---
|
||||
# 1. 尝试拟合椭圆 (需要轮廓点至少为5个)
|
||||
if len(largest) >= 5:
|
||||
# 返回值: ((中心x, 中心y), (长轴, 短轴), 旋转角度)
|
||||
(x, y), (axes_major, axes_minor), angle = cv2.fitEllipse(largest)
|
||||
|
||||
# 2. 计算半径
|
||||
# 选项A:取长短轴的平均值 (比较稳健)
|
||||
# radius = (axes_major + axes_minor) / 4
|
||||
|
||||
# 选项B:直接取短轴的一半 (抗模糊最强,推荐)
|
||||
radius = axes_minor / 2
|
||||
|
||||
best_center = (int(x), int(y))
|
||||
best_radius = int(radius)
|
||||
method = "v2_ellipse"
|
||||
else:
|
||||
# 如果点太少无法拟合椭圆,降级回 minEnclosingCircle
|
||||
(x, y), radius = cv2.minEnclosingCircle(largest)
|
||||
best_center = (int(x), int(y))
|
||||
best_radius = int(radius)
|
||||
method = "v2"
|
||||
# --- 核心修改结束 ---
|
||||
|
||||
# 你的后续逻辑
|
||||
best_radius1 = radius * 5
|
||||
|
||||
|
||||
# operas 4.5
|
||||
# R:25 M:v2 D:2.9554872521538527
|
||||
# hsv = cv2.cvtColor(img_cv, cv2.COLOR_RGB2HSV)
|
||||
# h, s, v = cv2.split(hsv)
|
||||
|
||||
# # 1. 适度增强饱和度(不要过度,否则噪声也会增强)
|
||||
# s = np.clip(s * 1.5, 0, 255).astype(np.uint8)
|
||||
# hsv = cv2.merge((h, s, v))
|
||||
|
||||
# # 2. 放宽 HSV 阈值范围(关键改动)
|
||||
# # - 饱和度下限从 80 降到 40(捕捉淡黄色)
|
||||
# # - 亮度上限从 182 提高到 255(允许更亮的黄色)
|
||||
# lower_yellow = np.array([7, 40, 30])
|
||||
# upper_yellow = np.array([35, 255, 255])
|
||||
|
||||
# mask = cv2.inRange(hsv, lower_yellow, upper_yellow)
|
||||
|
||||
# # 3. 调整形态学操作:用 CLOSE 替代 OPEN
|
||||
# # CLOSE(先膨胀后腐蚀):填充内部空洞,连接相邻区域
|
||||
# # OPEN(先腐蚀后膨胀):会缩小区域,不适合模糊图像
|
||||
# kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)) # 稍大的核
|
||||
# mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
|
||||
# mask = cv2.dilate(mask, kernel, iterations=1) # 额外膨胀,确保边缘被包含
|
||||
|
||||
# contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
# if contours:
|
||||
# largest = max(contours, key=cv2.contourArea)
|
||||
# if cv2.contourArea(largest) > 50:
|
||||
# (x, y), radius = cv2.minEnclosingCircle(largest)
|
||||
# best_center = (int(x), int(y))
|
||||
# best_radius = int(radius)
|
||||
# best_radius1 = radius * 5
|
||||
# method = "v2"
|
||||
|
||||
# # --- 新增:将 Mask 叠加到原图上用于调试 ---
|
||||
# # 创建一个彩色掩码(红色通道为255,其他为0)
|
||||
# mask_overlay = np.zeros_like(img_cv)
|
||||
# mask_overlay[:, :, 2] = mask # 将掩码放在红色通道 (BGR中的R)
|
||||
#
|
||||
# cv2.addWeighted(img_cv, 0.6, mask_overlay, 0.4, 0, img_cv)
|
||||
|
||||
result_img = image.cv2image(img_cv, False, False)
|
||||
return result_img, best_center, best_radius, method, best_radius1
|
||||
|
||||
|
||||
def detect_circle_v2(frame):
|
||||
"""检测图像中的靶心(优先清晰轮廓,其次黄色区域)- 返回椭圆参数版本"""
|
||||
global REAL_RADIUS_CM
|
||||
img_cv = image.image2cv(frame, False, False)
|
||||
|
||||
best_center = best_radius = best_radius1 = method = None
|
||||
ellipse_params = None # 存储椭圆参数 ((x, y), (axes_major, axes_minor), angle)
|
||||
|
||||
# HSV 黄色掩码检测(模糊靶心)
|
||||
hsv = cv2.cvtColor(img_cv, cv2.COLOR_RGB2HSV)
|
||||
h, s, v = cv2.split(hsv)
|
||||
|
||||
# 调整饱和度策略:稍微增强,不要过度
|
||||
s = np.clip(s * 1.1, 0, 255).astype(np.uint8)
|
||||
|
||||
hsv = cv2.merge((h, s, v))
|
||||
|
||||
# 放宽 HSV 阈值范围(针对模糊图像的关键调整)
|
||||
lower_yellow = np.array([7, 80, 0]) # 饱和度下限降低,捕捉淡黄色
|
||||
upper_yellow = np.array([32, 255, 255]) # 亮度上限拉满
|
||||
|
||||
mask = cv2.inRange(hsv, lower_yellow, upper_yellow)
|
||||
|
||||
# 调整形态学操作
|
||||
# 使用 MORPH_CLOSE (先膨胀后腐蚀) 来填充内部小黑洞,连接近邻区域
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
|
||||
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
|
||||
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
if contours:
|
||||
largest = max(contours, key=cv2.contourArea)
|
||||
|
||||
if cv2.contourArea(largest) > 50:
|
||||
# 尝试拟合椭圆 (需要轮廓点至少为5个)
|
||||
if len(largest) >= 5:
|
||||
# 返回值: ((中心x, 中心y), (width, height), 旋转角度)
|
||||
# 注意:width 和 height 是外接矩形的尺寸,不是长轴和短轴
|
||||
(x, y), (width, height), angle = cv2.fitEllipse(largest)
|
||||
|
||||
# 保存椭圆参数(保持原始顺序,用于绘制)
|
||||
ellipse_params = ((x, y), (width, height), angle)
|
||||
|
||||
# 计算半径:使用较小的尺寸作为短轴
|
||||
axes_minor = min(width, height)
|
||||
radius = axes_minor / 2
|
||||
|
||||
best_center = (int(x), int(y))
|
||||
best_radius = int(radius)
|
||||
method = "v2_ellipse"
|
||||
else:
|
||||
# 如果点太少无法拟合椭圆,降级回 minEnclosingCircle
|
||||
(x, y), radius = cv2.minEnclosingCircle(largest)
|
||||
best_center = (int(x), int(y))
|
||||
best_radius = int(radius)
|
||||
method = "v2"
|
||||
ellipse_params = None # 圆形,没有椭圆参数
|
||||
|
||||
best_radius1 = radius * 5
|
||||
|
||||
result_img = image.cv2image(img_cv, False, False)
|
||||
return result_img, best_center, best_radius, method, best_radius1, ellipse_params
|
||||
|
||||
# ==================== 测试逻辑 ====================
|
||||
|
||||
def run_offline_test(image_path):
|
||||
"""读取图片,检测圆,绘制结果,保存图片"""
|
||||
|
||||
# 1. 检查文件是否存在
|
||||
if not os.path.exists(image_path):
|
||||
print(f"[ERROR] 找不到图片文件: {image_path}")
|
||||
return
|
||||
|
||||
# 2. 使用 maix.image 读取图片 (适配 MaixPy v4)
|
||||
try:
|
||||
# 使用 image.load 读取文件,返回 Image 对象
|
||||
img = image.load(image_path)
|
||||
print(f"[INFO] 成功读取图片: {image_path} (尺寸: {img.width()}x{img.height()})")
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 读取图片失败: {e}")
|
||||
print("提示:请确认 MaixPy 版本是否为 v4,且图片路径正确。")
|
||||
return
|
||||
|
||||
|
||||
# 3. 调用 detect_circle_v2 函数
|
||||
print("[INFO] 正在调用 detect_circle_v2 进行检测...")
|
||||
start_time = time.ticks_ms()
|
||||
|
||||
result_img, center, radius, method, radius1, ellipse_params = detect_circle_v3(img)
|
||||
|
||||
cost_time = time.ticks_ms() - start_time
|
||||
print(f"[INFO] 检测完成,耗时: {cost_time}ms")
|
||||
print(f" 结果 -> 圆心: {center}, 半径: {radius}, 方法: {method}")
|
||||
if ellipse_params:
|
||||
(ell_center, (width, height), angle) = ellipse_params
|
||||
print(f" 椭圆 -> 中心: ({ell_center[0]:.1f}, {ell_center[1]:.1f}), 长轴: {max(width, height):.1f}, 短轴: {min(width, height):.1f}, 角度: {angle:.1f}°")
|
||||
|
||||
# 4. 绘制辅助线(可选,用于调试)
|
||||
if center and radius:
|
||||
# 为了绘制椭圆,需要转换回 cv2 图像
|
||||
img_cv = image.image2cv(result_img, False, False)
|
||||
|
||||
cx, cy = center
|
||||
|
||||
# 如果有椭圆参数,绘制椭圆
|
||||
if ellipse_params:
|
||||
(ell_center, (width, height), angle) = ellipse_params
|
||||
cx_ell, cy_ell = int(ell_center[0]), int(ell_center[1])
|
||||
|
||||
# 确定长轴和短轴
|
||||
if width >= height:
|
||||
# width 是长轴,height 是短轴
|
||||
axes_major = width
|
||||
axes_minor = height
|
||||
major_angle = angle # 长轴角度就是 angle
|
||||
minor_angle = angle + 90 # 短轴角度 = 长轴角度 + 90度
|
||||
else:
|
||||
# height 是长轴,width 是短轴
|
||||
axes_major = height
|
||||
axes_minor = width
|
||||
major_angle = angle + 90 # 长轴角度 = width角度 + 90度
|
||||
minor_angle = angle # 短轴角度就是 angle
|
||||
|
||||
# 使用 OpenCV 绘制椭圆(绿色,线宽2)
|
||||
cv2.ellipse(img_cv,
|
||||
(cx_ell, cy_ell), # 中心点
|
||||
(int(width/2), int(height/2)), # 半宽、半高
|
||||
angle, # 旋转角度(OpenCV需要原始angle)
|
||||
0, 360, # 起始和结束角度
|
||||
(0, 255, 0), # 绿色 (RGB格式)
|
||||
2) # 线宽
|
||||
|
||||
# 绘制椭圆中心点(红色)
|
||||
cv2.circle(img_cv, (cx_ell, cy_ell), 3, (255, 0, 0), -1)
|
||||
|
||||
import math
|
||||
# 绘制短轴(蓝色线条)
|
||||
minor_length = axes_minor / 2
|
||||
minor_angle_rad = math.radians(minor_angle)
|
||||
dx_minor = minor_length * math.cos(minor_angle_rad)
|
||||
dy_minor = minor_length * math.sin(minor_angle_rad)
|
||||
pt1_minor = (int(cx_ell - dx_minor), int(cy_ell - dy_minor))
|
||||
pt2_minor = (int(cx_ell + dx_minor), int(cy_ell + dy_minor))
|
||||
cv2.line(img_cv, pt1_minor, pt2_minor, (0, 0, 255), 2) # 蓝色 (RGB格式)
|
||||
else:
|
||||
# 如果没有椭圆参数,绘制圆形(红色)
|
||||
cv2.circle(img_cv, (cx, cy), radius, (0, 0, 255), 2)
|
||||
cv2.circle(img_cv, (cx, cy), 2, (0, 0, 255), -1)
|
||||
|
||||
# 转换回 maix image
|
||||
result_img = image.cv2image(img_cv, False, False)
|
||||
|
||||
# 定义颜色对象用于文字
|
||||
try:
|
||||
color_black = image.Color.from_rgb(0,0,0)
|
||||
except AttributeError:
|
||||
color_black = image.Color(0,0,0)
|
||||
|
||||
# D. 添加文字信息
|
||||
FOCAL_LENGTH_PIX = 1900
|
||||
d = (REAL_RADIUS_CM * FOCAL_LENGTH_PIX) / radius1 / 100.0
|
||||
info_str = f"R:{radius} M:{method} D:{d:.2f}"
|
||||
print(info_str)
|
||||
|
||||
# 计算文字位置,防止超出图片边界
|
||||
r_outer = int(radius * 11.0) if radius else 100
|
||||
text_y = cy - r_outer - 20 if cy > r_outer + 20 else cy + r_outer + 20
|
||||
|
||||
# 调用 draw_string
|
||||
result_img.draw_string(0, 0, info_str, color=color_black, scale=1.0)
|
||||
|
||||
|
||||
# 5. 保存结果图片
|
||||
output_path = image_path.replace(".bmp", "_result.bmp")
|
||||
output_path = image_path.replace(".jpg", "_result.jpg")
|
||||
try:
|
||||
result_img.save(output_path, quality=100)
|
||||
print(f"[SUCCESS] 结果已保存至: {output_path}")
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 保存图片失败: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ================= 配置区域 =================
|
||||
|
||||
# 1. 设置要测试的图片路径
|
||||
# 建议将图片放在与脚本同级目录,或者使用绝对路径
|
||||
TARGET_IMAGE = "/root/phot/None_314_258_0_0041.bmp"
|
||||
|
||||
# TARGET_DIR = "/root/phot_test2" # 修改为你想要读取的目录路径
|
||||
|
||||
# 支持的图片格式
|
||||
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp']
|
||||
|
||||
# ================= 执行区域 =================
|
||||
if 'TARGET_DIR' in locals():
|
||||
# 读取目录下所有图片文件,过滤掉 _result.jpg 后缀的文件
|
||||
image_files = []
|
||||
if os.path.exists(TARGET_DIR) and os.path.isdir(TARGET_DIR):
|
||||
for filename in os.listdir(TARGET_DIR):
|
||||
# 检查文件扩展名
|
||||
if any(filename.lower().endswith(ext) for ext in IMAGE_EXTENSIONS):
|
||||
# 过滤掉 _result.jpg 后缀的文件
|
||||
if not filename.endswith('_result.jpg'):
|
||||
filepath = os.path.join(TARGET_DIR, filename)
|
||||
if os.path.isfile(filepath):
|
||||
image_files.append(filepath)
|
||||
|
||||
# 按文件名排序(可选)
|
||||
image_files.sort()
|
||||
|
||||
print(f"[INFO] 在目录 {TARGET_DIR} 中找到 {len(image_files)} 张图片")
|
||||
|
||||
# 处理每张图片
|
||||
for img_path in image_files:
|
||||
print(f"\n{'='*10} 开始处理: {img_path} {'='*10}")
|
||||
run_offline_test(img_path)
|
||||
else:
|
||||
print(f"[ERROR] 目录不存在或不是有效目录: {TARGET_DIR}")
|
||||
|
||||
else:
|
||||
run_offline_test(TARGET_IMAGE)
|
||||
61
test/test_i2c.py
Normal file
61
test/test_i2c.py
Normal file
@@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env python3
|
||||
# test_i2c_devices.py
|
||||
|
||||
import os
|
||||
from maix import i2c
|
||||
|
||||
def list_i2c_devices():
|
||||
"""List available I2C device nodes"""
|
||||
print("Available I2C devices:")
|
||||
|
||||
# Check /dev directory
|
||||
try:
|
||||
dev_files = os.listdir("/dev")
|
||||
i2c_devices = [f for f in dev_files if "i2c" in f]
|
||||
if i2c_devices:
|
||||
for dev in sorted(i2c_devices):
|
||||
print(f" /dev/{dev}")
|
||||
else:
|
||||
print(" No /dev/i2c-* devices found!")
|
||||
except Exception as e:
|
||||
print(f" Error listing /dev: {e}")
|
||||
|
||||
def try_i2c_bus(bus_num):
|
||||
"""Try to initialize an I2C bus"""
|
||||
try:
|
||||
bus = i2c.I2C(bus_num, i2c.Mode.MASTER)
|
||||
print(f" I2C bus {bus_num}: OK")
|
||||
return True
|
||||
except RuntimeError as e:
|
||||
print(f" I2C bus {bus_num}: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f" I2C bus {bus_num}: Unexpected error: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("I2C Device Diagnostic")
|
||||
print("=" * 60)
|
||||
|
||||
# List kernel devices
|
||||
list_i2c_devices()
|
||||
|
||||
# Try common bus numbers
|
||||
print("\nTesting I2C buses:")
|
||||
working_buses = []
|
||||
for bus_num in range(10):
|
||||
if try_i2c_bus(bus_num):
|
||||
working_buses.append(bus_num)
|
||||
|
||||
print(f"\nWorking buses: {working_buses}")
|
||||
|
||||
if not working_buses:
|
||||
print("\nERROR: No I2C buses available!")
|
||||
print("Possible causes:")
|
||||
print(" 1. I2C kernel driver not loaded")
|
||||
print(" 2. Device tree doesn't enable I2C")
|
||||
print(" 3. Different kernel version with different device naming")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
246
test/test_laser.py
Normal file
246
test/test_laser.py
Normal file
@@ -0,0 +1,246 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
M01激光测距模块测试脚本 - 修正版
|
||||
基于文档中的完整命令示例
|
||||
"""
|
||||
|
||||
from maix import uart, pinmap, time
|
||||
import binascii
|
||||
|
||||
# ==================== 配置 ====================
|
||||
UART_PORT = "/dev/ttyS1"
|
||||
BAUDRATE = 9600
|
||||
|
||||
# 初始化串口
|
||||
try:
|
||||
pinmap.set_pin_function("A18", "UART1_RX")
|
||||
pinmap.set_pin_function("A19", "UART1_TX")
|
||||
laser_uart = uart.UART(UART_PORT, BAUDRATE)
|
||||
print("✅ 硬件初始化完成")
|
||||
except Exception as e:
|
||||
print(f"❌ 初始化失败: {e}")
|
||||
exit(1)
|
||||
|
||||
# ==================== 根据文档的完整命令集 ====================
|
||||
# 1. 激光开关(文档2.3.10,已验证可用)
|
||||
LASER_ON_CMD = bytes([0xAA, 0x00, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x01, 0xC1])
|
||||
LASER_OFF_CMD = bytes([0xAA, 0x00, 0x01, 0xBE, 0x00, 0x01, 0x00, 0x00, 0xC0])
|
||||
|
||||
# 2. 尝试不同的测距命令格式
|
||||
TEST_COMMANDS = [
|
||||
# 格式1:文档2.3.12的单次测量(您测试失败的)
|
||||
{
|
||||
"name": "单次测量 (0x0020)",
|
||||
"cmd": bytes([0xAA, 0x00, 0x00, 0x20, 0x00, 0x01, 0x00, 0x00, 0x21]),
|
||||
"desc": "文档2.3.12 示例命令"
|
||||
},
|
||||
# 格式2:文档2.3.7的读取测量结果
|
||||
{
|
||||
"name": "读取测量结果 (0x0022)",
|
||||
"cmd": bytes([0xAA, 0x80, 0x00, 0x22, 0xA2]),
|
||||
"desc": "文档2.3.7 读取测量结果"
|
||||
},
|
||||
# 格式3:文档2.3.13的快速测量
|
||||
{
|
||||
"name": "快速测量 (0x0022带数据)",
|
||||
"cmd": bytes([0xAA, 0x00, 0x00, 0x22, 0x00, 0x01, 0x00, 0x00, 0x23]),
|
||||
"desc": "文档2.3.13 快速测量"
|
||||
},
|
||||
# 格式4:连续测量模式
|
||||
{
|
||||
"name": "连续测量模式 (0x0021)",
|
||||
"cmd": bytes([0xAA, 0x00, 0x00, 0x21, 0x00, 0x01, 0x00, 0x00, 0x22]),
|
||||
"desc": "文档2.3.14 连续测量"
|
||||
}
|
||||
]
|
||||
|
||||
def clear_buffer():
|
||||
"""清空串口缓冲区"""
|
||||
try:
|
||||
data = laser_uart.read(-1)
|
||||
if data:
|
||||
print(f"清空: {len(data)}字节")
|
||||
except:
|
||||
pass
|
||||
|
||||
def send_and_wait(cmd, name, wait_time=2000):
|
||||
"""发送命令并等待响应"""
|
||||
print(f"\n📤 发送: {name}")
|
||||
print(f" 命令: {cmd.hex()}")
|
||||
|
||||
clear_buffer()
|
||||
|
||||
try:
|
||||
laser_uart.write(cmd)
|
||||
print(f" 已发送 {len(cmd)} 字节")
|
||||
except Exception as e:
|
||||
print(f" ❌ 发送失败: {e}")
|
||||
return None
|
||||
|
||||
# 等待响应
|
||||
start_time = time.ticks_ms()
|
||||
response = b""
|
||||
|
||||
while time.ticks_ms() - start_time < wait_time:
|
||||
try:
|
||||
chunk = laser_uart.read(1)
|
||||
if chunk:
|
||||
response += chunk
|
||||
# 完整响应通常是9或13字节
|
||||
if len(response) >= 9:
|
||||
# 检查是否完整帧
|
||||
if response[0] in [0xAA, 0xEE]:
|
||||
if len(response) >= 13: # 测距完整响应
|
||||
break
|
||||
elif response[0] == 0xEE: # 错误响应
|
||||
break
|
||||
except:
|
||||
break
|
||||
|
||||
time.sleep_ms(10)
|
||||
|
||||
if response:
|
||||
print(f" 📥 响应: {response.hex()}")
|
||||
print(f" 长度: {len(response)} 字节")
|
||||
|
||||
# 解析错误码
|
||||
if response[0] == 0xEE and len(response) >= 9:
|
||||
err_code = (response[7] << 8) | response[8]
|
||||
error_mapping = {
|
||||
0x0000: "无错误",
|
||||
0x0001: "硬件错误",
|
||||
0x0002: "无输出数据",
|
||||
0x0003: "反射信号太弱",
|
||||
0x0004: "反射信号太强",
|
||||
0x0005: "温度太高(>40℃)",
|
||||
0x0006: "温度太低(<-10℃)",
|
||||
0x0007: "电源电压低(<2.5V)",
|
||||
0x0008: "超出量程",
|
||||
0x0009: "读通讯错误",
|
||||
0x000A: "写通讯错误",
|
||||
0x000B: "地址错误"
|
||||
}
|
||||
err_msg = error_mapping.get(err_code, f"未知错误: 0x{err_code:04X}")
|
||||
print(f" ❌ 模块错误: {err_msg}")
|
||||
else:
|
||||
print(" ⚠️ 无响应")
|
||||
|
||||
return response
|
||||
|
||||
def parse_distance_data(response):
|
||||
"""解析距离数据"""
|
||||
if not response or len(response) < 13:
|
||||
return None
|
||||
|
||||
if response[0] != 0xAA or response[3] not in [0x20, 0x21, 0x22]:
|
||||
return None
|
||||
|
||||
# 解析4字节BCD码
|
||||
bcd_bytes = response[6:10]
|
||||
distance_int = 0
|
||||
|
||||
for byte in bcd_bytes:
|
||||
high = (byte >> 4) & 0x0F
|
||||
low = byte & 0x0F
|
||||
|
||||
if high > 9 or low > 9:
|
||||
return None
|
||||
|
||||
distance_int = distance_int * 100 + high * 10 + low
|
||||
|
||||
distance_m = distance_int / 1000.0
|
||||
|
||||
# 信号质量
|
||||
signal = 0
|
||||
if len(response) >= 12:
|
||||
signal = (response[10] << 8) | response[11]
|
||||
|
||||
return {
|
||||
'meters': distance_m,
|
||||
'millimeters': distance_m * 1000,
|
||||
'signal': signal,
|
||||
'raw': response.hex()
|
||||
}
|
||||
|
||||
# ==================== 主测试 ====================
|
||||
print("\n" + "="*50)
|
||||
print("M01激光测距模块详细测试")
|
||||
print("="*50)
|
||||
|
||||
try:
|
||||
# 1. 测试基本连接
|
||||
print("\n1. 测试模块连接...")
|
||||
version_cmd = bytes([0xAA, 0x80, 0x00, 0x0A, 0x8A])
|
||||
resp = send_and_wait(version_cmd, "读取硬件版本")
|
||||
|
||||
if resp and resp[0] == 0xAA and resp[3] == 0x0A:
|
||||
print(f"✅ 模块正常,版本: {resp[6]:02X}{resp[7]:02X}")
|
||||
else:
|
||||
print("❌ 模块连接测试失败")
|
||||
exit(1)
|
||||
|
||||
# 2. 开启激光
|
||||
print("\n2. 开启激光...")
|
||||
resp = send_and_wait(LASER_ON_CMD, "开启激光", 1000)
|
||||
if resp and resp.hex() == "aa0001be00010001c1":
|
||||
print("✅ 激光已开启")
|
||||
|
||||
print(" 等待激光稳定...")
|
||||
time.sleep(2) # 重要等待时间
|
||||
|
||||
# 3. 尝试不同的测距命令
|
||||
print("\n3. 测试不同测距命令...")
|
||||
|
||||
for i, test_cmd in enumerate(TEST_COMMANDS):
|
||||
print(f"\n{'='*30}")
|
||||
print(f"测试 {i+1}: {test_cmd['name']}")
|
||||
print(f"{test_cmd['desc']}")
|
||||
print(f"{'='*30}")
|
||||
|
||||
resp = send_and_wait(test_cmd['cmd'], test_cmd['name'], 3000)
|
||||
|
||||
if resp:
|
||||
if resp[0] == 0xAA and len(resp) >= 13:
|
||||
result = parse_distance_data(resp)
|
||||
if result:
|
||||
print(f"✅ 测距成功!")
|
||||
print(f" 距离: {result['meters']:.3f} m")
|
||||
print(f" 距离: {result['millimeters']:.1f} mm")
|
||||
print(f" 信号质量: {result['signal']}")
|
||||
break
|
||||
else:
|
||||
print("❌ 无法解析距离数据")
|
||||
elif resp[0] == 0xEE:
|
||||
print("❌ 命令执行错误")
|
||||
else:
|
||||
print("❌ 无效响应格式")
|
||||
else:
|
||||
print("❌ 无响应")
|
||||
|
||||
time.sleep(1) # 命令间间隔
|
||||
|
||||
# 4. 关闭激光
|
||||
print("\n4. 关闭激光...")
|
||||
send_and_wait(LASER_OFF_CMD, "关闭激光", 1000)
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("🏁 测试完成")
|
||||
print("="*50)
|
||||
|
||||
print("\n📋 测试总结:")
|
||||
print("1. 模块通信: ✅ 正常")
|
||||
print("2. 激光控制: ✅ 正常")
|
||||
print("3. 测距功能: ❌ 有问题")
|
||||
print("\n建议:")
|
||||
print("1. 检查激光是否实际发光(在暗处观察红点)")
|
||||
print("2. 确保测量目标在有效范围内(0.2-60米)")
|
||||
print("3. 确保目标有足够反射率(白色平面最佳)")
|
||||
print("4. 如果所有测距命令都返回ERR_ADDR,可能是固件版本问题")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n🛑 用户中断")
|
||||
laser_uart.write(LASER_OFF_CMD)
|
||||
print("✅ 已发送关闭指令")
|
||||
except Exception as e:
|
||||
print(f"\n❌ 测试出错: {e}")
|
||||
16
test/test_motor.py
Normal file
16
test/test_motor.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from maix import gpio, pinmap, time
|
||||
|
||||
|
||||
#设置引脚为输出
|
||||
led = gpio.GPIO("A25", gpio.Mode.OUT)
|
||||
#设置低电平
|
||||
led.value(0)
|
||||
|
||||
while 1:
|
||||
# time.sleep_ms(1000)
|
||||
#对该引脚的电平进行取反(原高-》现低)
|
||||
# led.toggle()
|
||||
led.value(1)
|
||||
#延时
|
||||
time.sleep_ms(5000)
|
||||
led.value(0)
|
||||
130
test/test_power.py
Normal file
130
test/test_power.py
Normal file
@@ -0,0 +1,130 @@
|
||||
#!/usr/bin/env python3
|
||||
# test_power_with_init.py
|
||||
|
||||
from maix import i2c, time
|
||||
import sys
|
||||
|
||||
# INA226 register addresses
|
||||
INA226_ADDR = 0x40
|
||||
REG_CONFIGURATION = 0x00
|
||||
REG_BUS_VOLTAGE = 0x02
|
||||
REG_CURRENT = 0x04
|
||||
REG_CALIBRATION = 0x05
|
||||
|
||||
# Configuration values
|
||||
CONFIG_VALUE = 0x4527 # Configuration: 16 averages, 1.1ms conversion time, continuous mode
|
||||
CALIBRATION_VALUE = 0x1400 # Calibration value
|
||||
|
||||
def write_register(bus, reg, value):
|
||||
"""Write to INA226 register"""
|
||||
data = [(value >> 8) & 0xFF, value & 0xFF]
|
||||
bus.writeto_mem(INA226_ADDR, reg, bytes(data))
|
||||
|
||||
def read_register(bus, reg):
|
||||
"""Read from INA226 register"""
|
||||
data = bus.readfrom_mem(INA226_ADDR, reg, 2)
|
||||
return (data[0] << 8) | data[1]
|
||||
|
||||
def init_ina226(bus):
|
||||
"""Initialize INA226 chip"""
|
||||
try:
|
||||
# Write configuration register
|
||||
write_register(bus, REG_CONFIGURATION, CONFIG_VALUE)
|
||||
time.sleep_ms(10)
|
||||
|
||||
# Write calibration register
|
||||
write_register(bus, REG_CALIBRATION, CALIBRATION_VALUE)
|
||||
time.sleep_ms(10)
|
||||
|
||||
# Verify configuration by reading it back
|
||||
config_read = read_register(bus, REG_CONFIGURATION)
|
||||
if config_read != CONFIG_VALUE:
|
||||
print(f" Warning: Config readback mismatch: 0x{config_read:04X} != 0x{CONFIG_VALUE:04X}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f" Init failed: {e}")
|
||||
return False
|
||||
|
||||
def read_voltage(bus):
|
||||
"""Read bus voltage"""
|
||||
raw = read_register(bus, REG_BUS_VOLTAGE)
|
||||
voltage = raw * 1.25 / 1000
|
||||
return voltage
|
||||
|
||||
def read_current(bus):
|
||||
"""Read current"""
|
||||
raw = read_register(bus, REG_CURRENT)
|
||||
# Handle signed value
|
||||
if raw & 0x8000:
|
||||
raw = raw - 0x10000
|
||||
current_lsb = 0.001 * CALIBRATION_VALUE / 4096
|
||||
current = raw * current_lsb * 1000 # mA
|
||||
return current
|
||||
|
||||
def test_i2c_bus(bus_num):
|
||||
"""Test a single I2C bus with full initialization"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Testing I2C Bus {bus_num}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
try:
|
||||
# Step 1: Initialize I2C bus
|
||||
print(f" 1. Initializing I2C bus...")
|
||||
bus = i2c.I2C(bus_num, i2c.Mode.MASTER)
|
||||
print(f" OK")
|
||||
|
||||
# Step 2: Initialize INA226
|
||||
print(f" 2. Initializing INA226...")
|
||||
if not init_ina226(bus):
|
||||
print(f" FAILED")
|
||||
return False
|
||||
print(f" OK")
|
||||
|
||||
# Step 3: Read voltage multiple times
|
||||
print(f" 3. Reading voltage...")
|
||||
for i in range(5):
|
||||
try:
|
||||
voltage = read_voltage(bus)
|
||||
current = read_current(bus)
|
||||
print(f" Read {i+1}: {voltage:.3f}V, {current:.1f}mA")
|
||||
time.sleep_ms(100)
|
||||
except Exception as e:
|
||||
print(f" Read {i+1} failed: {e}")
|
||||
|
||||
print(f" SUCCESS")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f" FAILED: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Test all I2C buses"""
|
||||
print("INA226 Test with Proper Initialization")
|
||||
print("=" * 60)
|
||||
|
||||
# Test buses in order of likelihood
|
||||
test_order = [5, 1, 3, 4, 0, 2]
|
||||
|
||||
success_buses = []
|
||||
|
||||
for bus_num in test_order:
|
||||
if test_i2c_bus(bus_num):
|
||||
success_buses.append(bus_num)
|
||||
# If we found a working bus, stop testing others
|
||||
break
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Summary:")
|
||||
print(f" Working buses: {success_buses}")
|
||||
if not success_buses:
|
||||
print(f" ERROR: No working I2C bus found!")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
186
time_sync.py
Normal file
186
time_sync.py
Normal file
@@ -0,0 +1,186 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
时间同步模块
|
||||
从4G模块获取时间并同步到系统
|
||||
"""
|
||||
import re
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
import config
|
||||
# from logger_bak import get_logger
|
||||
from logger_manager import logger_manager
|
||||
|
||||
|
||||
def parse_4g_time(cclk_response, timezone_offset=8):
|
||||
"""
|
||||
解析 AT+CCLK? 返回的时间字符串,并转换为本地时间
|
||||
|
||||
Args:
|
||||
cclk_response: AT+CCLK? 的响应字符串
|
||||
timezone_offset: 时区偏移(小时),默认8(中国时区 UTC+8)
|
||||
|
||||
Returns:
|
||||
datetime 对象(已转换为本地时间),如果解析失败返回 None
|
||||
"""
|
||||
try:
|
||||
# 匹配格式: +CCLK: "YY/MM/DD,HH:MM:SS+TZ"
|
||||
# 时区单位是四分之一小时(quarters of an hour)
|
||||
match = re.search(r'\+CCLK:\s*"(\d{2})/(\d{2})/(\d{2}),(\d{2}):(\d{2}):(\d{2})([+-]\d{1,3})?"', cclk_response)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
yy, mm, dd, hh, MM, ss, tz_str = match.groups()
|
||||
|
||||
# 年份处理:26 -> 2026
|
||||
year = 2000 + int(yy)
|
||||
month = int(mm)
|
||||
day = int(dd)
|
||||
hour = int(hh)
|
||||
minute = int(MM)
|
||||
second = int(ss)
|
||||
|
||||
# 创建 UTC 时间的 datetime 对象
|
||||
dt_utc = datetime(year, month, day, hour, minute, second)
|
||||
|
||||
# 解析时区偏移(单位:四分之一小时)
|
||||
if tz_str:
|
||||
try:
|
||||
# 时区偏移值(四分之一小时)
|
||||
tz_quarters = int(tz_str)
|
||||
|
||||
# 转换为小时(除以4)
|
||||
tz_hours = tz_quarters / 4.0
|
||||
|
||||
logger = logger_manager.logger
|
||||
if logger:
|
||||
logger.info(f"[TIME] 时区偏移: {tz_str} (四分之一小时) = {tz_hours} 小时")
|
||||
|
||||
# 转换为本地时间
|
||||
dt_local = dt_utc + timedelta(hours=tz_hours)
|
||||
except ValueError:
|
||||
# 如果时区解析失败,使用默认值
|
||||
logger = logger_manager.logger
|
||||
if logger:
|
||||
logger.warning(f"[TIME] 时区解析失败: {tz_str},使用默认 UTC+{timezone_offset}")
|
||||
dt_local = dt_utc + timedelta(hours=timezone_offset)
|
||||
else:
|
||||
# 没有时区信息,使用默认值
|
||||
logger = logger_manager.logger
|
||||
if logger:
|
||||
logger.info(f"[TIME] 未找到时区信息,使用默认 UTC+{timezone_offset}")
|
||||
dt_local = dt_utc + timedelta(hours=timezone_offset)
|
||||
|
||||
logger = logger_manager.logger
|
||||
if logger:
|
||||
logger.info(f"[TIME] UTC时间: {dt_utc.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
logger.info(f"[TIME] 本地时间: {dt_local.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
return dt_local
|
||||
except Exception as e:
|
||||
logger = logger_manager.logger
|
||||
if logger:
|
||||
logger.error(f"[TIME] 解析时间失败: {e}, 响应: {cclk_response}")
|
||||
else:
|
||||
print(f"[TIME] 解析时间失败: {e}, 响应: {cclk_response}")
|
||||
return None
|
||||
|
||||
|
||||
def get_time_from_4g(timezone_offset=8):
|
||||
"""
|
||||
通过4G模块获取当前时间(已转换为本地时间)
|
||||
|
||||
Args:
|
||||
timezone_offset: 时区偏移(小时),默认8(中国时区)
|
||||
|
||||
Returns:
|
||||
datetime 对象(本地时间),如果获取失败返回 None
|
||||
"""
|
||||
try:
|
||||
# 发送 AT+CCLK? 命令(延迟导入避免循环依赖)
|
||||
from hardware import hardware_manager
|
||||
# 检查 at_client 是否已初始化
|
||||
if hardware_manager.at_client is None:
|
||||
logger = logger_manager.logger
|
||||
if logger:
|
||||
logger.warning("[TIME] ATClient 尚未初始化,无法获取4G时间")
|
||||
else:
|
||||
print("[TIME] ATClient 尚未初始化,无法获取4G时间")
|
||||
return None
|
||||
resp = hardware_manager.at_client.send("AT+CCLK?", "OK", 3000)
|
||||
|
||||
if not resp or "+CCLK:" not in resp:
|
||||
logger = logger_manager.logger
|
||||
if logger:
|
||||
logger.warning(f"[TIME] 未获取到时间响应: {resp}")
|
||||
else:
|
||||
print(f"[TIME] 未获取到时间响应: {resp}")
|
||||
return None
|
||||
|
||||
# 解析并转换时区
|
||||
dt = parse_4g_time(resp, timezone_offset)
|
||||
return dt
|
||||
except Exception as e:
|
||||
logger = logger_manager.logger
|
||||
if logger:
|
||||
logger.error(f"[TIME] 获取4G时间异常: {e}")
|
||||
else:
|
||||
print(f"[TIME] 获取4G时间异常: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def sync_system_time_from_4g(timezone_offset=8):
|
||||
"""
|
||||
从4G模块同步时间到系统
|
||||
|
||||
Args:
|
||||
timezone_offset: 时区偏移(小时),默认8(中国时区)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
dt = get_time_from_4g(timezone_offset)
|
||||
if not dt:
|
||||
return False
|
||||
|
||||
try:
|
||||
# 转换为系统 date 命令需要的格式
|
||||
time_str = dt.strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
# 设置系统时间
|
||||
cmd = f'date -s "{time_str}" 2>&1'
|
||||
result = os.system(cmd)
|
||||
|
||||
if result == 0:
|
||||
logger = logger_manager.logger
|
||||
if logger:
|
||||
logger.info(f"[TIME] 系统时间已设置为: {time_str}")
|
||||
else:
|
||||
print(f"[TIME] 系统时间已设置为: {time_str}")
|
||||
|
||||
# 可选:同步到硬件时钟
|
||||
try:
|
||||
os.system('hwclock -w 2>/dev/null')
|
||||
logger = logger_manager.logger
|
||||
if logger:
|
||||
logger.info("[TIME] 已同步到硬件时钟")
|
||||
except:
|
||||
pass
|
||||
|
||||
return True
|
||||
else:
|
||||
logger = logger_manager.logger
|
||||
if logger:
|
||||
logger.error(f"[TIME] 设置系统时间失败,退出码: {result}")
|
||||
else:
|
||||
print(f"[TIME] 设置系统时间失败,退出码: {result}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger = logger_manager.logger
|
||||
if logger:
|
||||
logger.error(f"[TIME] 同步系统时间异常: {e}")
|
||||
else:
|
||||
print(f"[TIME] 同步系统时间异常: {e}")
|
||||
return False
|
||||
|
||||
1005
train_yolo/synth_compose_yolo.py
Normal file
1005
train_yolo/synth_compose_yolo.py
Normal file
File diff suppressed because it is too large
Load Diff
343
train_yolo/test_stage2_black_yolo_device.py
Normal file
343
train_yolo/test_stage2_black_yolo_device.py
Normal file
@@ -0,0 +1,343 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Stage2 黑三角 YOLO —— 在 Maix 设备上用本地图片测试(与线上 target_roi_yolo.try_black_triangle_boxes_work 完全一致)。
|
||||
|
||||
不在 PC 上跑 NPU;需把脚本与 config / target_roi_yolo.py 同步到设备,并在设备上执行。
|
||||
|
||||
典型用法
|
||||
--------
|
||||
# 输入已是 Stage1 裁切(与你保存的 stage2_roi_*.jpg 一致)
|
||||
python test/test_stage2_black_yolo_device.py /root/phot/stage2_roi_xxx.jpg
|
||||
|
||||
# 输入为整幅相机图,手动给出 Stage1 环靶 ROI(与线上日志 ring全图=[rx0,ry0,rx1,ry1] 一致)
|
||||
python test/test_stage2_black_yolo_device.py /root/phot/full.jpg --roi 197,196,507,461
|
||||
|
||||
# 对比 native / letterbox 坐标映射(排查 contain 训练与推理对齐)
|
||||
python test/test_stage2_black_yolo_device.py ./crop.jpg --compare-coord
|
||||
|
||||
# 覆盖置信度、模型路径(仍读其余项自 config)
|
||||
python test/test_stage2_black_yolo_device.py ./crop.jpg --conf 0.25 -m /maixapp/apps/t11/model_270648.mud
|
||||
|
||||
# 只看 NPU 原始框(映射前):判断坐标是 ~224 网络空间还是归一化 0~1
|
||||
python test/test_stage2_black_yolo_device.py ./crop.jpg --conf 0.05 --dump-raw 15
|
||||
|
||||
依赖:MaixPy(maix.nn)、OpenCV(cv2)、numpy;项目根须在 sys.path(本脚本已插入上级目录)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if _ROOT not in sys.path:
|
||||
sys.path.insert(0, _ROOT)
|
||||
|
||||
|
||||
def _parse_roi(s: str) -> tuple[int, int, int, int]:
|
||||
parts = [p.strip() for p in s.replace(" ", "").split(",")]
|
||||
if len(parts) != 4:
|
||||
raise ValueError("ROI 需要 4 个整数:x0,y0,x1,y1")
|
||||
return tuple(int(x) for x in parts) # type: ignore[return-value]
|
||||
|
||||
|
||||
def _load_rgb_numpy(path: str) -> "object":
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
bgr = cv2.imread(path, cv2.IMREAD_COLOR)
|
||||
if bgr is None:
|
||||
raise FileNotFoundError(f"cv2.imread 失败: {path}")
|
||||
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
||||
return np.ascontiguousarray(rgb, dtype=np.uint8)
|
||||
|
||||
|
||||
def _draw_boxes_on_crop(
|
||||
slab_rgb,
|
||||
boxes: list[tuple[int, int, int, int]],
|
||||
labels: list[str] | None = None,
|
||||
):
|
||||
"""slab_rgb: H×W×3 RGB uint8;boxes 为扩 margin 后的 Stage2 子框(与线上绿框一致)。"""
|
||||
import cv2
|
||||
|
||||
vis = slab_rgb.copy()
|
||||
bgr = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
|
||||
rh, rw = bgr.shape[:2]
|
||||
for i, (bx0, by0, bx1, by1) in enumerate(boxes):
|
||||
x0, y0 = int(bx0), int(by0)
|
||||
x1, y1 = int(bx1) - 1, int(by1) - 1
|
||||
x1 = max(x0, min(x1, rw - 1))
|
||||
y1 = max(y0, min(y1, rh - 1))
|
||||
cv2.rectangle(bgr, (x0, y0), (x1, y1), (0, 255, 0), 2)
|
||||
tag = labels[i] if labels and i < len(labels) else f"s2_{i}"
|
||||
cv2.putText(
|
||||
bgr,
|
||||
tag,
|
||||
(x0, max(0, y0 - 4)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(0, 255, 0),
|
||||
1,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
return bgr
|
||||
|
||||
|
||||
class _PrintLogger:
|
||||
def info(self, msg):
|
||||
print(msg)
|
||||
|
||||
def warning(self, msg):
|
||||
print(msg)
|
||||
|
||||
def error(self, msg):
|
||||
print(msg)
|
||||
|
||||
|
||||
def _run_once(yroi_mod, img_rgb, roi_xyxy, logger):
|
||||
boxes = yroi_mod.try_black_triangle_boxes_work(img_rgb, roi_xyxy, logger)
|
||||
rx0, ry0, rx1, ry1 = roi_xyxy
|
||||
slab = img_rgb[ry0:ry1, rx0:rx1].copy()
|
||||
return boxes, slab
|
||||
|
||||
|
||||
def _copy_dump_raw_rows(yroi_mod, objs):
|
||||
"""把 Maix detect 返回对象拷贝成基础类型,避免 native 对象跨下一次 detect 存活。"""
|
||||
rows = []
|
||||
for o in objs:
|
||||
cid = yroi_mod._det_obj_class_id(o)
|
||||
try:
|
||||
sc = float(getattr(o, "score", 0.0))
|
||||
except (TypeError, ValueError):
|
||||
sc = 0.0
|
||||
rows.append((cid, sc, float(o.x), float(o.y), float(o.w), float(o.h)))
|
||||
return rows
|
||||
|
||||
|
||||
def _dump_raw_and_hard_exit(det, yroi_mod, slab_for_det, rw_s, rh_s, net_w, net_h, conf_th, iou_th, limit):
|
||||
"""
|
||||
MaixPy 某些版本在 YOLO detect 返回对象正常析构时会 SIGSEGV/pure virtual。
|
||||
raw dump 是诊断路径,打印完成后硬退出,绕过 Python/native 析构链。
|
||||
"""
|
||||
from maix import image as maix_image
|
||||
|
||||
roi_maix = maix_image.cv2image(slab_for_det, False, False)
|
||||
raw = det.detect(roi_maix, conf_th=conf_th, iou_th=iou_th)
|
||||
objs = yroi_mod._normalize_objs(raw if raw is not None else [])
|
||||
dump_rows = _copy_dump_raw_rows(yroi_mod, objs)
|
||||
raw_count = len(dump_rows)
|
||||
print(
|
||||
f"[DUMP-RAW] slab={rw_s}×{rh_s} net={net_w}×{net_h} "
|
||||
f"conf={conf_th} iou={iou_th} → NMS 后 raw 框数={raw_count}(与 coord_mode 无关)"
|
||||
)
|
||||
npr = min(int(limit), raw_count)
|
||||
for i in range(npr):
|
||||
cid, sc, x, y, ww, hh = dump_rows[i]
|
||||
print(f" #{i} cls={cid} score={sc:.4f} xywh=({x:.3f},{y:.3f},{ww:.3f},{hh:.3f})")
|
||||
if dump_rows:
|
||||
xs = [r[2] for r in dump_rows]
|
||||
ws = [r[4] for r in dump_rows]
|
||||
print(
|
||||
f"[DUMP-RAW] hint: x 范围≈[{min(xs):.2f},{max(xs):.2f}] "
|
||||
f"w 范围≈[{min(ws):.2f},{max(ws):.2f}] — "
|
||||
f"若整体在 0~{net_w} 量级多为网络画布坐标→应用 letterbox;"
|
||||
f"若 x,w 多在 0~1→可能是归一化,需在代码里乘 net 尺寸"
|
||||
)
|
||||
print("[INFO] --dump-raw 已完成;为规避 MaixPy YOLO native 析构崩溃,测试进程将直接退出。")
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
os._exit(0)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(
|
||||
description="Stage2 黑三角 YOLO 设备本地图测试",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=__doc__,
|
||||
)
|
||||
ap.add_argument("image", help="本地图片路径(设备上的路径)")
|
||||
ap.add_argument(
|
||||
"--roi",
|
||||
default="",
|
||||
metavar="x0,y0,x1,y1",
|
||||
help="可选。若填写:image 为整幅图,在此图上取 Stage1 ROI 再跑 Stage2;"
|
||||
"留空:image 本身就是 Stage1 裁切图(默认)",
|
||||
)
|
||||
ap.add_argument("-o", "--output", default="", help="输出可视化路径;默认 原名_stage2_vis.jpg")
|
||||
ap.add_argument("-m", "--model", default="", help="覆盖 config.TRIANGLE_BLACK_YOLO_MODEL_PATH")
|
||||
ap.add_argument("--conf", type=float, default=None, help="覆盖 TRIANGLE_BLACK_YOLO_CONF_TH")
|
||||
ap.add_argument("--iou", type=float, default=None, help="覆盖 TRIANGLE_BLACK_YOLO_IOU_TH")
|
||||
ap.add_argument(
|
||||
"--coord",
|
||||
choices=["native", "letterbox"],
|
||||
default="",
|
||||
help="覆盖 TRIANGLE_BLACK_YOLO_COORD_MODE;默认用 config",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--compare-coord",
|
||||
action="store_true",
|
||||
help="各跑一次 native 与 letterbox,输出两张图 *_stage2_native.jpg / *_stage2_letterbox.jpg",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--fresh-detector",
|
||||
action="store_true",
|
||||
help="清掉 YOLO 缓存再测(换模型或排查缓存时用)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--allow-save-roi",
|
||||
action="store_true",
|
||||
help="不强制关闭 TRIANGLE_BLACK_YOLO_SAVE_ROI_CROP(默认测试时会关掉以免写满相册目录)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dump-raw",
|
||||
type=int,
|
||||
default=0,
|
||||
metavar="N",
|
||||
help="打印前 N 个 detect 原始框 x,y,w,h,score,cls(coord 映射前;native/letterbox 共用同一批 raw)",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
img_path = os.path.abspath(args.image)
|
||||
if not os.path.isfile(img_path):
|
||||
print(f"[ERR] 找不到图片: {img_path}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import config as cfg
|
||||
import target_roi_yolo as yroi
|
||||
except ImportError as e:
|
||||
print(f"[ERR] 无法导入 config / target_roi_yolo: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if args.fresh_detector:
|
||||
yroi.reset_yolo_detector_cache()
|
||||
|
||||
# 备份并临时覆盖 config(单进程顺序跑)
|
||||
bak: dict[str, object] = {}
|
||||
|
||||
def _patch(key: str, val: object):
|
||||
if key not in bak:
|
||||
bak[key] = getattr(cfg, key, None)
|
||||
setattr(cfg, key, val)
|
||||
|
||||
def _restore():
|
||||
for k, v in bak.items():
|
||||
setattr(cfg, k, v)
|
||||
|
||||
try:
|
||||
_patch("TRIANGLE_BLACK_YOLO_ENABLE", True)
|
||||
if not args.allow_save_roi:
|
||||
_patch("TRIANGLE_BLACK_YOLO_SAVE_ROI_CROP", False)
|
||||
if args.model.strip():
|
||||
_patch("TRIANGLE_BLACK_YOLO_MODEL_PATH", args.model.strip())
|
||||
if args.conf is not None:
|
||||
_patch("TRIANGLE_BLACK_YOLO_CONF_TH", float(args.conf))
|
||||
if args.iou is not None:
|
||||
_patch("TRIANGLE_BLACK_YOLO_IOU_TH", float(args.iou))
|
||||
if args.coord and not args.compare_coord:
|
||||
_patch("TRIANGLE_BLACK_YOLO_COORD_MODE", args.coord)
|
||||
|
||||
mp = getattr(cfg, "TRIANGLE_BLACK_YOLO_MODEL_PATH", "") or ""
|
||||
if not os.path.isfile(mp):
|
||||
print(f"[ERR] 模型文件不存在: {mp}")
|
||||
sys.exit(1)
|
||||
|
||||
img_rgb = _load_rgb_numpy(img_path)
|
||||
h, w = int(img_rgb.shape[0]), int(img_rgb.shape[1])
|
||||
|
||||
if args.roi.strip():
|
||||
roi_xyxy = _parse_roi(args.roi.strip())
|
||||
rx0, ry0, rx1, ry1 = [int(round(float(v))) for v in roi_xyxy]
|
||||
if rx1 <= rx0 or ry1 <= ry0:
|
||||
print("[ERR] ROI 无效:需满足 x1>x0 且 y1>y0")
|
||||
sys.exit(1)
|
||||
# 与 target_roi_yolo.try_black_triangle_boxes_work 相同的 clip
|
||||
rx0 = max(0, min(rx0, w - 1))
|
||||
ry0 = max(0, min(ry0, h - 1))
|
||||
rx1 = max(rx0 + 1, min(rx1, w))
|
||||
ry1 = max(ry0 + 1, min(ry1, h))
|
||||
ring_roi = (rx0, ry0, rx1, ry1)
|
||||
print(f"[INFO] 模式=整图+ROI ring={ring_roi} image={w}×{h}")
|
||||
else:
|
||||
ring_roi = (0, 0, w, h)
|
||||
print(f"[INFO] 模式=已是 Stage1 裁切 crop={w}×{h}")
|
||||
|
||||
logger = _PrintLogger()
|
||||
det = yroi._get_detector(mp)
|
||||
if det is None:
|
||||
print("[ERR] 无法加载 nn.YOLOv5(检查模型路径与 Maix 环境)")
|
||||
sys.exit(1)
|
||||
net_w = int(det.input_width())
|
||||
net_h = int(det.input_height())
|
||||
print(f"[INFO] model={mp} net_in={net_w}×{net_h}")
|
||||
|
||||
rx0, ry0, rx1, ry1 = ring_roi
|
||||
import numpy as np
|
||||
|
||||
slab_for_det = np.ascontiguousarray(img_rgb[ry0:ry1, rx0:rx1], dtype=np.uint8).copy()
|
||||
rh_s, rw_s = int(slab_for_det.shape[0]), int(slab_for_det.shape[1])
|
||||
|
||||
modes = ["native", "letterbox"] if args.compare_coord else [
|
||||
(args.coord or getattr(cfg, "TRIANGLE_BLACK_YOLO_COORD_MODE", "native"))
|
||||
]
|
||||
|
||||
base, ext = os.path.splitext(img_path)
|
||||
ext = ext if ext else ".jpg"
|
||||
|
||||
for mode in modes:
|
||||
_patch("TRIANGLE_BLACK_YOLO_COORD_MODE", mode)
|
||||
cur_coord = getattr(cfg, "TRIANGLE_BLACK_YOLO_COORD_MODE", mode)
|
||||
print(f"[INFO] --- TRIANGLE_BLACK_YOLO_COORD_MODE={cur_coord} ---")
|
||||
|
||||
boxes, slab = _run_once(yroi, img_rgb, ring_roi, logger)
|
||||
print(
|
||||
f"[INFO] 子框数量={len(boxes)} conf={getattr(cfg, 'TRIANGLE_BLACK_YOLO_CONF_TH', '?')} "
|
||||
f"coord={cur_coord}"
|
||||
)
|
||||
for i, b in enumerate(boxes):
|
||||
print(f" s2_{i}: {b}")
|
||||
|
||||
if args.compare_coord:
|
||||
out_path = f"{base}_stage2_{mode}{ext}"
|
||||
elif args.output.strip():
|
||||
out_path = args.output.strip()
|
||||
else:
|
||||
out_path = base + "_stage2_vis" + ext
|
||||
|
||||
import cv2
|
||||
|
||||
bgr = _draw_boxes_on_crop(slab, boxes)
|
||||
cv2.imwrite(out_path, bgr, [int(cv2.IMWRITE_JPEG_QUALITY), 92])
|
||||
print(f"[OK] saved: {out_path}")
|
||||
|
||||
if args.compare_coord:
|
||||
print(
|
||||
"[HINT] contain 训练时若 letterbox 对齐更好,请将 config 里 "
|
||||
"TRIANGLE_BLACK_YOLO_COORD_MODE 设为 letterbox"
|
||||
)
|
||||
|
||||
if args.dump_raw > 0:
|
||||
conf_th = float(getattr(cfg, "TRIANGLE_BLACK_YOLO_CONF_TH", 0.5))
|
||||
iou_th = float(getattr(cfg, "TRIANGLE_BLACK_YOLO_IOU_TH", 0.45))
|
||||
print("\n[INFO] --dump-raw 放在最后执行,避免 raw native 对象影响 compare-coord 流程。")
|
||||
_dump_raw_and_hard_exit(
|
||||
det,
|
||||
yroi,
|
||||
slab_for_det,
|
||||
rw_s,
|
||||
rh_s,
|
||||
net_w,
|
||||
net_h,
|
||||
conf_th,
|
||||
iou_th,
|
||||
args.dump_raw,
|
||||
)
|
||||
|
||||
finally:
|
||||
_restore()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
242
train_yolo/test_triangle_one_image.py
Normal file
242
train_yolo/test_triangle_one_image.py
Normal file
@@ -0,0 +1,242 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
单张图片快速测试:三角形四角标记识别 + 单应性落点 + PnP 估距
|
||||
|
||||
用法(在板子上):
|
||||
python3 test/test_triangle_one_image.py --image /root/phot/xxx.jpg --out /root/phot/tri_out.jpg
|
||||
|
||||
调参对比(不改代码,临时覆盖 config.TRIANGLE_*):
|
||||
python3 test/test_triangle_one_image.py --image /root/phot/xxx.jpg --preset shadow
|
||||
python3 test/test_triangle_one_image.py --image /root/phot/xxx.jpg --max-interior-gray 160 --min-dark-ratio 0.20
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import config
|
||||
import triangle_target as tri_mod
|
||||
from triangle_target import (
|
||||
detect_triangle_markers,
|
||||
load_camera_from_xml,
|
||||
load_triangle_positions,
|
||||
try_triangle_scoring,
|
||||
)
|
||||
|
||||
|
||||
def _apply_overrides(args) -> None:
|
||||
# 预设:阴影/低对比度场景更宽松(尽量保持速度:不启 CLAHE)
|
||||
if args.preset == "shadow":
|
||||
setattr(config, "TRIANGLE_ENABLE_CLAHE_FALLBACK", False)
|
||||
setattr(config, "TRIANGLE_MIN_CONTRAST_DIFF", 0)
|
||||
setattr(config, "TRIANGLE_MAX_INTERIOR_GRAY", 160)
|
||||
setattr(config, "TRIANGLE_DARK_PIXEL_GRAY", 160)
|
||||
setattr(config, "TRIANGLE_MIN_DARK_RATIO", 0.20)
|
||||
# adaptive 只在 Otsu 失败时尝试,保持尝试次数很少
|
||||
setattr(config, "TRIANGLE_ADAPTIVE_BLOCK_SIZES", (21,))
|
||||
|
||||
# 手动覆盖(优先级高于 preset)
|
||||
if args.max_interior_gray is not None:
|
||||
setattr(config, "TRIANGLE_MAX_INTERIOR_GRAY", int(args.max_interior_gray))
|
||||
if args.dark_pixel_gray is not None:
|
||||
setattr(config, "TRIANGLE_DARK_PIXEL_GRAY", int(args.dark_pixel_gray))
|
||||
if args.min_dark_ratio is not None:
|
||||
setattr(config, "TRIANGLE_MIN_DARK_RATIO", float(args.min_dark_ratio))
|
||||
if args.min_contrast_diff is not None:
|
||||
setattr(config, "TRIANGLE_MIN_CONTRAST_DIFF", int(args.min_contrast_diff))
|
||||
if args.detect_scale is not None:
|
||||
setattr(config, "TRIANGLE_DETECT_SCALE", float(args.detect_scale))
|
||||
if args.adaptive_blocks is not None:
|
||||
bs = tuple(int(x) for x in args.adaptive_blocks.split(",") if x.strip())
|
||||
setattr(config, "TRIANGLE_ADAPTIVE_BLOCK_SIZES", bs)
|
||||
|
||||
|
||||
def _dump_config() -> Dict[str, Any]:
|
||||
keys = [
|
||||
"TRIANGLE_DETECT_SCALE",
|
||||
"TRIANGLE_SIZE_RANGE",
|
||||
"TRIANGLE_MAX_INTERIOR_GRAY",
|
||||
"TRIANGLE_DARK_PIXEL_GRAY",
|
||||
"TRIANGLE_MIN_DARK_RATIO",
|
||||
"TRIANGLE_MIN_CONTRAST_DIFF",
|
||||
"TRIANGLE_ADAPTIVE_BLOCK_SIZES",
|
||||
"TRIANGLE_MAX_FILTERED_FOR_COMBO",
|
||||
"TRIANGLE_EARLY_EXIT_CANDIDATES",
|
||||
"TRIANGLE_ENABLE_CLAHE_FALLBACK",
|
||||
]
|
||||
out = {}
|
||||
for k in keys:
|
||||
out[k] = getattr(config, k, None)
|
||||
return out
|
||||
|
||||
|
||||
def _draw_tri_debug(img_bgr: np.ndarray, tri: Dict[str, Any]) -> np.ndarray:
|
||||
out = img_bgr.copy()
|
||||
markers = tri.get("markers") or []
|
||||
|
||||
# 画三角形轮廓 + center + id
|
||||
for m in markers:
|
||||
corners = np.array(m.get("corners", []), dtype=np.int32)
|
||||
if corners.size == 0:
|
||||
continue
|
||||
cv2.polylines(out, [corners], True, (0, 255, 0), 2)
|
||||
c = m.get("center") or (corners[:, 0].mean(), corners[:, 1].mean())
|
||||
cx, cy = int(c[0]), int(c[1])
|
||||
cv2.circle(out, (cx, cy), 4, (0, 0, 255), -1)
|
||||
mid = m.get("id", "?")
|
||||
cv2.putText(out, f"T{mid}", (cx - 18, cy - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 0), 1)
|
||||
|
||||
# 若有 homography,画靶心(把 (0,0) 反投影到图像)
|
||||
H = tri.get("homography")
|
||||
if H is not None:
|
||||
try:
|
||||
H = np.array(H, dtype=np.float64)
|
||||
H_inv = np.linalg.inv(H)
|
||||
c_img = cv2.perspectiveTransform(np.array([[[0.0, 0.0]]], dtype=np.float32), H_inv)[0][0]
|
||||
ocx, ocy = int(c_img[0]), int(c_img[1])
|
||||
cv2.circle(out, (ocx, ocy), 5, (0, 0, 255), -1)
|
||||
cv2.circle(out, (ocx, ocy), 10, (0, 0, 255), 1)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 叠加结果信息
|
||||
lines = []
|
||||
if tri.get("ok"):
|
||||
lines.append("tri_ok=True")
|
||||
if tri.get("dx_cm") is not None and tri.get("dy_cm") is not None:
|
||||
lines.append(f"dx,dy=({tri['dx_cm']:.2f},{tri['dy_cm']:.2f})cm")
|
||||
if tri.get("distance_m") is not None:
|
||||
lines.append(f"dist={float(tri['distance_m']):.2f}m")
|
||||
else:
|
||||
lines.append("tri_ok=False")
|
||||
|
||||
y0 = 22
|
||||
for i, t in enumerate(lines):
|
||||
cv2.putText(out, t, (10, y0 + i * 18), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 0), 1)
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--image", required=True, help="输入图片路径(jpg/png)")
|
||||
ap.add_argument("--out", default="", help="输出标注图片路径(可选)")
|
||||
ap.add_argument("--laser-x", type=int, default=-1, help="激光点 x(像素),默认用图像中心")
|
||||
ap.add_argument("--laser-y", type=int, default=-1, help="激光点 y(像素),默认用图像中心")
|
||||
ap.add_argument("--preset", choices=["", "shadow"], default="", help="调参预设(shadow=阴影更鲁棒,不启 CLAHE)")
|
||||
ap.add_argument("--max-interior-gray", type=int, default=None)
|
||||
ap.add_argument("--dark-pixel-gray", type=int, default=None)
|
||||
ap.add_argument("--min-dark-ratio", type=float, default=None)
|
||||
ap.add_argument("--min-contrast-diff", type=int, default=None)
|
||||
ap.add_argument("--detect-scale", type=float, default=None)
|
||||
ap.add_argument("--adaptive-blocks", default=None, help="例如: 11,21 ;为空表示不改")
|
||||
ap.add_argument("--verbose", action="store_true", help="输出更多检测阶段信息")
|
||||
args = ap.parse_args()
|
||||
|
||||
_apply_overrides(args)
|
||||
# triangle_target.py 的日志默认写到 logger_manager;在离线脚本里 logger 可能未初始化。
|
||||
# verbose 模式下把 _log 重定向为 print,方便直接看到诊断信息。
|
||||
if args.verbose:
|
||||
try:
|
||||
tri_mod._log = lambda msg: print(str(msg))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
img_bgr = cv2.imread(args.image, cv2.IMREAD_COLOR)
|
||||
if img_bgr is None:
|
||||
raise SystemExit(f"读图失败:{args.image}")
|
||||
# triangle_target.try_triangle_scoring 约定输入为 RGB;OpenCV imread 返回 BGR
|
||||
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
||||
|
||||
h, w = img_bgr.shape[:2]
|
||||
if args.laser_x >= 0 and args.laser_y >= 0:
|
||||
laser_point = (int(args.laser_x), int(args.laser_y))
|
||||
else:
|
||||
laser_point = (w // 2, h // 2)
|
||||
|
||||
K, dist = load_camera_from_xml(getattr(config, "CAMERA_CALIB_XML", ""))
|
||||
pos = load_triangle_positions(getattr(config, "TRIANGLE_POSITIONS_JSON", ""))
|
||||
|
||||
print("[tri-test] image:", args.image, "shape:", (h, w))
|
||||
print("[tri-test] laser_point:", laser_point)
|
||||
print("[tri-test] calib_ok:", bool(K is not None and dist is not None), "pos_ok:", bool(pos))
|
||||
print("[tri-test] config:", json.dumps(_dump_config(), ensure_ascii=False))
|
||||
|
||||
# 先单独跑一次三角形候选检测,便于区分“没找到候选” vs “找到候选但评分/单应性失败”
|
||||
scale = float(getattr(config, "TRIANGLE_DETECT_SCALE", 0.5) or 0.5)
|
||||
if not (0.05 <= scale <= 1.0):
|
||||
scale = 0.5
|
||||
long_side = max(h, w)
|
||||
max_dim = max(64, int(long_side * scale))
|
||||
if long_side > max_dim:
|
||||
det_scale = max_dim / long_side
|
||||
det_w = int(w * det_scale)
|
||||
det_h = int(h * det_scale)
|
||||
img_det = cv2.resize(img_bgr, (det_w, det_h), interpolation=cv2.INTER_LINEAR)
|
||||
inv_scale = 1.0 / det_scale
|
||||
size_range_det = (
|
||||
max(4, int(getattr(config, "TRIANGLE_SIZE_RANGE", (8, 500))[0] * det_scale)),
|
||||
max(8, int(getattr(config, "TRIANGLE_SIZE_RANGE", (8, 500))[1] * det_scale)),
|
||||
)
|
||||
else:
|
||||
img_det = img_bgr
|
||||
inv_scale = 1.0
|
||||
size_range_det = getattr(config, "TRIANGLE_SIZE_RANGE", (8, 500))
|
||||
|
||||
gray = cv2.cvtColor(img_det, cv2.COLOR_BGR2GRAY)
|
||||
markers_det = detect_triangle_markers(
|
||||
gray,
|
||||
orig_gray=gray,
|
||||
size_range=size_range_det,
|
||||
verbose=bool(args.verbose),
|
||||
)
|
||||
if inv_scale != 1.0 and markers_det:
|
||||
for m in markers_det:
|
||||
m["center"] = [m["center"][0] * inv_scale, m["center"][1] * inv_scale]
|
||||
m["corners"] = [[c[0] * inv_scale, c[1] * inv_scale] for c in m["corners"]]
|
||||
|
||||
print("[tri-test] markers_found:", len(markers_det), "ids:", [m.get("id") for m in markers_det])
|
||||
|
||||
t0 = time.time()
|
||||
tri = try_triangle_scoring(
|
||||
img_rgb, # try_triangle_scoring 期望 RGB
|
||||
laser_point,
|
||||
pos,
|
||||
K,
|
||||
dist,
|
||||
size_range=getattr(config, "TRIANGLE_SIZE_RANGE", (8, 500)),
|
||||
)
|
||||
dt_ms = int(round((time.time() - t0) * 1000))
|
||||
|
||||
print("[tri-test] elapsed_ms:", dt_ms)
|
||||
print(json.dumps(tri, ensure_ascii=False, indent=2))
|
||||
|
||||
if args.out:
|
||||
out_path = args.out
|
||||
# 允许传目录(如 ./),自动生成文件名;未带扩展名时默认 .jpg
|
||||
if out_path.endswith("/") or out_path.endswith("\\") or os.path.isdir(out_path):
|
||||
out_path = os.path.join(out_path, "tri_out.jpg")
|
||||
root, ext = os.path.splitext(out_path)
|
||||
if not ext:
|
||||
out_path = root + ".jpg"
|
||||
|
||||
# 若 try_triangle_scoring 失败且没带回 markers,至少把候选 markers 画出来,方便肉眼判断
|
||||
tri_for_draw = tri if isinstance(tri, dict) else {"ok": False}
|
||||
if not tri_for_draw.get("markers") and markers_det:
|
||||
tri_for_draw = dict(tri_for_draw)
|
||||
tri_for_draw["markers"] = markers_det
|
||||
out_img = _draw_tri_debug(img_bgr, tri_for_draw)
|
||||
ok = cv2.imwrite(out_path, out_img)
|
||||
if not ok:
|
||||
raise SystemExit(f"写图失败(可能是不支持的扩展名):{out_path}")
|
||||
print("[tri-test] wrote:", out_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
257
train_yolo/test_yolo_draw_boxes.py
Normal file
257
train_yolo/test_yolo_draw_boxes.py
Normal file
@@ -0,0 +1,257 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
本地图片 → Maix YOLOv5 检测 → 画框保存(用于核对坐标 mode / 多框 union)。
|
||||
|
||||
运行环境:MaixCAM / MaixPy(需 maix.image / maix.nn),在项目根或任意目录执行均可。
|
||||
|
||||
示例:
|
||||
python test/test_yolo_draw_boxes.py /root/phot/shot_xxx.jpg
|
||||
python test/test_yolo_draw_boxes.py shot.jpg --loader cv2_rgb --conf 0.25
|
||||
python test/test_yolo_draw_boxes.py shot.jpg --debug
|
||||
python -h # 查看 --loader / --debug / --union 等全部参数
|
||||
|
||||
脚本版本(与设备同步用):20260206-yolo-vis
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if _ROOT not in sys.path:
|
||||
sys.path.insert(0, _ROOT)
|
||||
|
||||
|
||||
def _load_maix_image(path: str, image_mod):
|
||||
"""maix.image.load(部分 JPEG 解码后与 camera.read() 像素布局不一致,可能导致 NPU 全空)。"""
|
||||
return image_mod.load(path)
|
||||
|
||||
|
||||
def _load_cv2_rgb_as_maix(path: str, image_mod):
|
||||
"""
|
||||
OpenCV 读盘为 BGR → 转 RGB → 与 shoot_manager 里 image2cv 逆过程一致,供 YOLO input type: rgb。
|
||||
"""
|
||||
import cv2
|
||||
|
||||
arr = cv2.imread(path, cv2.IMREAD_COLOR)
|
||||
if arr is None:
|
||||
raise FileNotFoundError(f"cv2.imread 失败: {path}")
|
||||
arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
|
||||
return image_mod.cv2image(arr, False, False)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(
|
||||
description="YOLO 画框测试(Maix)",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="若提示 unrecognized arguments: --debug,说明设备上脚本未更新,请同步仓库中的 test/test_yolo_draw_boxes.py",
|
||||
)
|
||||
ap.add_argument("image", help="输入图片路径")
|
||||
ap.add_argument("-o", "--output", default="", help="输出图片路径;默认 原名_yolo_vis.jpg")
|
||||
ap.add_argument("-m", "--model", default="", help="覆盖 config.TRIANGLE_YOLO_MODEL_PATH")
|
||||
ap.add_argument("--conf", type=float, default=None, help="置信度阈值")
|
||||
ap.add_argument("--iou", type=float, default=None, help="NMS IoU")
|
||||
ap.add_argument(
|
||||
"--coord",
|
||||
choices=["native", "letterbox"],
|
||||
default="",
|
||||
help="坐标映射;默认读 config.TRIANGLE_YOLO_COORD_MODE",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--union",
|
||||
action="store_true",
|
||||
help="按 TRIANGLE_YOLO_RING_CLASS_IDS 过滤后画合并外接矩形(与线上 ROI merge=union 一致)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--loader",
|
||||
choices=["auto", "maix", "cv2_rgb"],
|
||||
default="auto",
|
||||
help="auto: 先 maix.load,0 框则改用 cv2 RGB(推荐排查「有图但始终 0 框」)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="打印 detect 原始返回类型与 repr(截断)",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
try:
|
||||
from maix import image, nn
|
||||
except ImportError:
|
||||
print("[ERR] 需要 MaixPy(maix.image / maix.nn),请在 MaixCAM 上运行。")
|
||||
sys.exit(1)
|
||||
|
||||
import config as cfg
|
||||
import target_roi_yolo as yroi
|
||||
|
||||
img_path = os.path.abspath(args.image)
|
||||
if not os.path.isfile(img_path):
|
||||
print(f"[ERR] 找不到图片: {img_path}")
|
||||
sys.exit(1)
|
||||
|
||||
model_path = (args.model or getattr(cfg, "TRIANGLE_YOLO_MODEL_PATH", "") or "").strip()
|
||||
if not os.path.isfile(model_path):
|
||||
print(f"[ERR] 模型文件不存在: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
conf_th = (
|
||||
float(args.conf)
|
||||
if args.conf is not None
|
||||
else float(getattr(cfg, "TRIANGLE_YOLO_CONF_TH", 0.5))
|
||||
)
|
||||
iou_th = (
|
||||
float(args.iou)
|
||||
if args.iou is not None
|
||||
else float(getattr(cfg, "TRIANGLE_YOLO_IOU_TH", 0.45))
|
||||
)
|
||||
coord_mode = (args.coord or getattr(cfg, "TRIANGLE_YOLO_COORD_MODE", "native")).lower()
|
||||
|
||||
out_path = args.output.strip()
|
||||
if not out_path:
|
||||
base, ext = os.path.splitext(img_path)
|
||||
ext = ext if ext else ".jpg"
|
||||
out_path = base + "_yolo_vis" + ext
|
||||
|
||||
det = nn.YOLOv5(model=model_path, dual_buff=False)
|
||||
net_w = int(det.input_width())
|
||||
net_h = int(det.input_height())
|
||||
|
||||
def _run_detect(maix_img, tag: str):
|
||||
r = det.detect(maix_img, conf_th=conf_th, iou_th=iou_th)
|
||||
if args.debug:
|
||||
rlen = len(r) if r is not None and hasattr(r, "__len__") else "n/a"
|
||||
rrepr = repr(r)
|
||||
if len(rrepr) > 300:
|
||||
rrepr = rrepr[:300] + "..."
|
||||
print(f"[DEBUG] loader={tag} raw_type={type(r)} len={rlen} repr={rrepr}")
|
||||
return yroi._normalize_objs(r if r is not None else []), maix_img, tag
|
||||
|
||||
img = None
|
||||
load_tag = ""
|
||||
objs = []
|
||||
|
||||
if args.loader == "cv2_rgb":
|
||||
img = _load_cv2_rgb_as_maix(img_path, image)
|
||||
load_tag = "cv2_rgb"
|
||||
objs, img, load_tag = _run_detect(img, load_tag)
|
||||
elif args.loader == "maix":
|
||||
img = _load_maix_image(img_path, image)
|
||||
load_tag = "maix_load"
|
||||
objs, img, load_tag = _run_detect(img, load_tag)
|
||||
else:
|
||||
# auto
|
||||
img = _load_maix_image(img_path, image)
|
||||
load_tag = "maix_load"
|
||||
objs, img, load_tag = _run_detect(img, load_tag)
|
||||
if len(objs) == 0:
|
||||
print(
|
||||
"[WARN] maix.image.load 在 conf_th=%s 下仍为 0 框,改用 cv2 BGR→RGB→cv2image 重试(常见可恢复)"
|
||||
% conf_th
|
||||
)
|
||||
img2 = _load_cv2_rgb_as_maix(img_path, image)
|
||||
objs, img, load_tag = _run_detect(img2, "cv2_rgb_retry")
|
||||
|
||||
src_w, src_h = img.width(), img.height()
|
||||
|
||||
labels = getattr(det, "labels", None)
|
||||
|
||||
def _label(cid: int) -> str:
|
||||
if labels is None:
|
||||
return str(cid)
|
||||
try:
|
||||
return str(labels[int(cid)])
|
||||
except Exception:
|
||||
return str(cid)
|
||||
|
||||
print(
|
||||
f"[INFO] loader={load_tag} image={src_w}×{src_h}, net_in={net_w}×{net_h}, "
|
||||
f"coord={coord_mode}, conf_th={conf_th}, iou_th={iou_th}"
|
||||
)
|
||||
print(f"[INFO] NMS 后检测框数量={len(objs)} → {out_path}")
|
||||
if len(objs) == 0:
|
||||
print(
|
||||
"[HINT] 仍为 0 框时常见原因:\n"
|
||||
" 1) 强制 cv2 路径: --loader cv2_rgb\n"
|
||||
" 2) NMS 过严: --iou 0.95\n"
|
||||
" 3) 图与训练分布差太大 / 模型未见过该场景\n"
|
||||
" 4) 用 camera.read() 一帧存盘再测,对比 file 与实时是否一致"
|
||||
)
|
||||
|
||||
# 颜色:按类别轮换(仅有 COLOR_* 时常量时用)
|
||||
color_cycle = []
|
||||
for name in ("RED", "GREEN", "BLUE", "ORANGE", "YELLOW", "CYAN", "MAGENTA"):
|
||||
c = getattr(image, f"COLOR_{name}", None)
|
||||
if c is not None:
|
||||
color_cycle.append(c)
|
||||
if not color_cycle:
|
||||
color_cycle = [getattr(image, "COLOR_RED", 0)]
|
||||
|
||||
for i, o in enumerate(objs):
|
||||
cid = yroi._det_obj_class_id(o)
|
||||
if cid is None:
|
||||
cid = -1
|
||||
try:
|
||||
sc = float(o.score)
|
||||
except Exception:
|
||||
sc = 0.0
|
||||
x0, y0, x1, y1 = yroi._det_to_src_xyxy(o, coord_mode, src_w, src_h, net_w, net_h)
|
||||
ix = int(max(0, min(x0, src_w - 1)))
|
||||
iy = int(max(0, min(y0, src_h - 1)))
|
||||
iw = int(max(1, min(x1 - x0, src_w - ix)))
|
||||
ih = int(max(1, min(y1 - y0, src_h - iy)))
|
||||
col = color_cycle[cid % len(color_cycle)] if cid >= 0 else color_cycle[0]
|
||||
img.draw_rect(ix, iy, iw, ih, color=col)
|
||||
ty = max(0, iy - 14)
|
||||
msg = f"{_label(cid)} {sc:.2f}"
|
||||
img.draw_string(ix, ty, msg, color=col)
|
||||
print(f" #{i} cls={cid} {_label(cid)} score={sc:.3f} xywh=({ix},{iy},{iw},{ih})")
|
||||
|
||||
if args.union:
|
||||
class_ids = getattr(cfg, "TRIANGLE_YOLO_RING_CLASS_IDS", (0,))
|
||||
if isinstance(class_ids, int):
|
||||
class_ids = (class_ids,)
|
||||
cand = [o for o in objs if yroi._det_obj_class_id(o) in class_ids]
|
||||
if cand:
|
||||
xy_list = [
|
||||
yroi._det_to_src_xyxy(o, coord_mode, src_w, src_h, net_w, net_h) for o in cand
|
||||
]
|
||||
merged = yroi._merge_roi_xyxy(xy_list, "union")
|
||||
if merged:
|
||||
mx0, my0, mx1, my1 = merged
|
||||
mx0 = max(0, min(mx0, src_w - 1))
|
||||
my0 = max(0, min(my0, src_h - 1))
|
||||
mx1 = max(mx0 + 1, min(mx1, src_w))
|
||||
my1 = max(my0 + 1, min(my1, src_h))
|
||||
uw, uh = int(mx1 - mx0), int(my1 - my0)
|
||||
ucol = getattr(image, "COLOR_GREEN", color_cycle[0])
|
||||
# 画粗一点的 union:描两遍错位矩形简易模拟加粗
|
||||
for d in (0, 2):
|
||||
img.draw_rect(
|
||||
int(mx0) - d,
|
||||
int(my0) - d,
|
||||
uw + 2 * d,
|
||||
uh + 2 * d,
|
||||
color=ucol,
|
||||
)
|
||||
img.draw_string(
|
||||
int(mx0),
|
||||
max(0, int(my0) - 28),
|
||||
f"UNION ({len(cand)} boxes)",
|
||||
color=ucol,
|
||||
)
|
||||
print(f"[INFO] UNION [{int(mx0)},{int(my0)},{int(mx1)},{int(my1)}] from {len(cand)} boxes")
|
||||
else:
|
||||
print("[WARN] --union 但 RING_CLASS_IDS 过滤后无框")
|
||||
|
||||
try:
|
||||
img.save(out_path, quality=95)
|
||||
except TypeError:
|
||||
img.save(out_path)
|
||||
print(f"[OK] saved: {out_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
506
train_yolo/train_black_triangle_pos.py
Normal file
506
train_yolo/train_black_triangle_pos.py
Normal file
@@ -0,0 +1,506 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
YOLO11 关键点检测训练脚本(靶纸四角)。
|
||||
|
||||
设备优先级(--device auto):Intel XPU > NVIDIA CUDA > CPU。
|
||||
默认 imgsz=960;批大小默认 4(大图显存紧张时可再降)。
|
||||
|
||||
关于「业务像素误差」:
|
||||
Ultralytics 没有在 yaml 里设定「像素阈值」的选项;反向传播仍由 pose/kobj/box 等内部 loss 驱动。
|
||||
- 监控:--pixel-metrics-every N(每 N 个 epoch 打印 mean/p95,并合并进 runs/.../results.csv,见 pose_pixel_metrics.py)。
|
||||
- 选 best.pt / early stopping:加 --best-by-pixel,用验证集 mean 像素误差(与 pose_pixel_metrics
|
||||
同一口径)代替 mAP 合成 fitness(fitness = -mean_px,越小越好)。
|
||||
多卡 DDP(world_size>1)时会自动退回默认 mAP fitness。
|
||||
|
||||
XPU:Ultralytics BaseTrainer._get_memory / _clear_memory 把非 MPS、非 CPU 一律当 CUDA,
|
||||
会在验证前调用 torch.cuda 而报错;本脚本在选用 XPU 时自动打补丁(见 _patch_ultralytics_trainer_for_xpu)。
|
||||
|
||||
务必使用 pose 任务:YOLO(...) 与 model.train(...) 均指定 task='pose'。若误用默认 detect,
|
||||
会把 17 列 Pose 标注当成检测/分割解析,校验时出现「coordinates > 1」或 [2.] 等假象。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import gc
|
||||
import glob
|
||||
import os
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from ultralytics import YOLO
|
||||
|
||||
from pose_pixel_metrics import eval_val_pixel_error
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore',
|
||||
message=".*scatter_add_kernel does not have a deterministic implementation.*")
|
||||
|
||||
|
||||
def _clear_ultralytics_label_caches(data_yaml_path: str) -> int:
|
||||
"""删除 data.yaml 的 path 下 labels/*.cache。
|
||||
|
||||
Ultralytics 的校验缓存 hash 仅依赖「标签/图片路径字符串 + 各文件 size 之和」,不含文件内容;
|
||||
修正 *.txt 后若总和巧合不变,可能继续加载旧 cache 并重播旧的 corrupt 日志,训练前应删掉。"""
|
||||
from ultralytics.utils import YAML
|
||||
|
||||
try:
|
||||
cfg = YAML.load(data_yaml_path)
|
||||
except Exception:
|
||||
return 0
|
||||
root = cfg.get("path")
|
||||
if not root:
|
||||
return 0
|
||||
root = os.path.abspath(os.path.expanduser(str(root)))
|
||||
pattern = os.path.join(root, "labels", "*.cache")
|
||||
n = 0
|
||||
for p in glob.glob(pattern):
|
||||
try:
|
||||
os.unlink(p)
|
||||
n += 1
|
||||
except OSError:
|
||||
pass
|
||||
return n
|
||||
|
||||
|
||||
def _pick_device(explicit: str | None):
|
||||
"""返回 ultralytics train/predict 可用的 device。"""
|
||||
if explicit and explicit != "auto":
|
||||
e = explicit.lower()
|
||||
if e == "xpu":
|
||||
if getattr(torch, "xpu", None) is None or not torch.xpu.is_available():
|
||||
raise RuntimeError("指定了 --device xpu 但当前环境不可用")
|
||||
return torch.device("xpu")
|
||||
if e in ("0", "cuda", "gpu"):
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("指定了 CUDA 但不可用")
|
||||
return 0
|
||||
if e == "cpu":
|
||||
return "cpu"
|
||||
return explicit
|
||||
if getattr(torch, "xpu", None) is not None and torch.xpu.is_available():
|
||||
return torch.device("xpu")
|
||||
if torch.cuda.is_available():
|
||||
return 0
|
||||
return "cpu"
|
||||
|
||||
|
||||
def _default_amp(device) -> bool:
|
||||
if isinstance(device, torch.device) and device.type == "xpu":
|
||||
return False
|
||||
if device == "cpu":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _patch_ultralytics_for_xpu():
|
||||
"""为 Ultralytics 打补丁,使其能在 XPU 环境下正常训练和验证。"""
|
||||
import ultralytics.engine.trainer as ut_trainer
|
||||
import ultralytics.engine.validator as ut_validator
|
||||
from ultralytics.utils.torch_utils import select_device as _original_select_device
|
||||
|
||||
# 1. 覆盖 select_device:Trainer 初始化传入 torch.device("xpu") 会走原版早返回;
|
||||
# 初始化后 args.device 会变成字符串 "xpu",中期 val 用 trainer.device,不调用 select_device;
|
||||
# 训练结束 final_eval 里 Validator 会 select_device("xpu"),且 validator 在 import 时已绑定原函数,
|
||||
# 只改 torch_utils 无效,必须同时修补 trainer/validator 模块内的引用。
|
||||
def _patched_select_device(device="", *args, **kwargs):
|
||||
# Ultralytics 8.4.x: select_device(device="", newline=False, verbose=True)
|
||||
# Older forks sometimes passed extra positional args; forward everything.
|
||||
if isinstance(device, str):
|
||||
d = device.strip().lower()
|
||||
if d == "xpu" or d.startswith("xpu:"):
|
||||
return torch.device(device.strip())
|
||||
return _original_select_device(device, *args, **kwargs)
|
||||
|
||||
import ultralytics.utils.torch_utils
|
||||
|
||||
ultralytics.utils.torch_utils.select_device = _patched_select_device
|
||||
ut_trainer.select_device = _patched_select_device
|
||||
ut_validator.select_device = _patched_select_device
|
||||
|
||||
# 2. 修补 Trainer 的内存函数
|
||||
BT = ut_trainer.BaseTrainer
|
||||
if not getattr(BT, "_archery_xpu_memory_patched", False):
|
||||
_orig_get_memory = BT._get_memory
|
||||
_orig_clear_memory = BT._clear_memory
|
||||
|
||||
def _get_memory(self, fraction=False):
|
||||
if self.device.type != "xpu":
|
||||
return _orig_get_memory(self, fraction)
|
||||
# ... (原有的 XPU 内存获取逻辑保持不变) ...
|
||||
memory, total = 0, 0
|
||||
try:
|
||||
idx = self.device.index
|
||||
if idx is None:
|
||||
idx = torch.xpu.current_device()
|
||||
memory = int(torch.xpu.memory_allocated(idx))
|
||||
if fraction:
|
||||
total = int(torch.xpu.get_device_properties(idx).total_memory)
|
||||
except Exception:
|
||||
pass
|
||||
return (memory / total) if fraction and total > 0 else (memory / 2**30)
|
||||
|
||||
def _clear_memory(self, threshold=None):
|
||||
if self.device.type != "xpu":
|
||||
return _orig_clear_memory(self, threshold)
|
||||
if threshold is not None:
|
||||
assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
|
||||
if self._get_memory(fraction=True) <= threshold:
|
||||
return
|
||||
gc.collect()
|
||||
if hasattr(torch.xpu, "empty_cache"):
|
||||
torch.xpu.empty_cache()
|
||||
|
||||
BT._get_memory = _get_memory
|
||||
BT._clear_memory = _clear_memory
|
||||
BT._archery_xpu_memory_patched = True
|
||||
|
||||
# 3. 修补 Validator 的内存函数 (关键是添加这部分)
|
||||
BV = ut_validator.BaseValidator
|
||||
if not getattr(BV, "_archery_xpu_memory_patched", False):
|
||||
# 为 Validator 添加同样的内存处理方法
|
||||
BV._get_memory = _get_memory
|
||||
BV._clear_memory = _clear_memory
|
||||
BV._archery_xpu_memory_patched = True
|
||||
|
||||
|
||||
def _install_best_by_pixel_validate(data_yaml: str, imgsz: int, conf: float) -> None:
|
||||
"""用验证集关键点像素 mean 替代 mAP fitness,驱动 best.pt 与 patience early stopping。"""
|
||||
import ultralytics.engine.trainer as ut
|
||||
from ultralytics.utils import RANK
|
||||
|
||||
BT = ut.BaseTrainer
|
||||
if getattr(BT, "_archery_best_by_pixel_installed", False):
|
||||
return
|
||||
|
||||
_orig_validate = BT.validate
|
||||
|
||||
def validate(self):
|
||||
import torch.distributed as dist
|
||||
|
||||
if self.ema and self.world_size > 1:
|
||||
for buffer in self.ema.ema.buffers():
|
||||
dist.broadcast(buffer, src=0)
|
||||
metrics = self.validator(self)
|
||||
if metrics is None:
|
||||
return None, None
|
||||
orig_fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())
|
||||
|
||||
use_pixel = self.world_size <= 1 and RANK in {-1, 0}
|
||||
mean_px: float | None = None
|
||||
if use_pixel:
|
||||
tmp_path: str | None = None
|
||||
try:
|
||||
fd, tmp_path = tempfile.mkstemp(suffix=".pt", prefix="archery_pxfit_")
|
||||
os.close(fd)
|
||||
from ultralytics.utils.torch_utils import unwrap_model
|
||||
|
||||
core = unwrap_model(self.ema.ema if self.ema else self.model)
|
||||
torch.save({"ema": deepcopy(core).half(), "train_args": vars(self.args)}, tmp_path)
|
||||
probe = YOLO(tmp_path)
|
||||
stats = eval_val_pixel_error(
|
||||
probe,
|
||||
data_yaml,
|
||||
device=self.device,
|
||||
imgsz=imgsz,
|
||||
conf=conf,
|
||||
)
|
||||
mean_px = stats.get("mean_px")
|
||||
if mean_px is None:
|
||||
raise RuntimeError("无有效 mean_px(检查 val 标签与检测是否为空)")
|
||||
except Exception as exc:
|
||||
print(f"\n⚠️ [best-by-pixel] 像素探针失败,本 epoch 仍用 mAP fitness: {exc}\n")
|
||||
mean_px = None
|
||||
finally:
|
||||
if tmp_path:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if mean_px is not None:
|
||||
fitness = -float(mean_px)
|
||||
metrics["metrics/mean_px(val)"] = float(mean_px)
|
||||
else:
|
||||
fitness = float(orig_fitness)
|
||||
|
||||
if not self.best_fitness or self.best_fitness < fitness:
|
||||
self.best_fitness = fitness
|
||||
return metrics, fitness
|
||||
|
||||
BT.validate = validate
|
||||
BT._archery_best_by_pixel_installed = True
|
||||
|
||||
|
||||
def _fmt_csv_metric(v: float | int | None) -> str:
|
||||
if v is None:
|
||||
return ""
|
||||
if isinstance(v, float):
|
||||
return f"{v:.6g}"
|
||||
return str(v)
|
||||
|
||||
|
||||
# 写入 results.csv 的列名(与 --best-by-pixel 的 metrics/mean_px(val) 区分,避免被 last.pt 回调覆盖 EMA 行)
|
||||
_PIXEL_METRIC_COLUMNS: tuple[tuple[str, str], ...] = (
|
||||
("pixel_error/mean_px", "mean_px"),
|
||||
("pixel_error/median_px", "median_px"),
|
||||
("pixel_error/p95_px", "p95_px"),
|
||||
("pixel_error/max_px", "max_px"),
|
||||
("pixel_error/n_points", "n_points"),
|
||||
("pixel_error/n_images", "n_images"),
|
||||
("pixel_error/skip_no_det", "skip_no_det"),
|
||||
("pixel_error/skip_no_gt", "skip_no_gt"),
|
||||
("pixel_error/skip_kpt_mismatch", "skip_kpt_mismatch"),
|
||||
)
|
||||
|
||||
|
||||
def _merge_pixel_metrics_into_results_csv(save_dir: str | Path, epoch_1based: int, stats: dict) -> None:
|
||||
"""在 Ultralytics 写完本 epoch 行之后,把像素指标列合并进 results.csv(扩展表头、补空列)。"""
|
||||
csv_path = Path(save_dir) / "results.csv"
|
||||
if not csv_path.is_file():
|
||||
return
|
||||
try:
|
||||
with open(csv_path, newline="", encoding="utf-8") as f:
|
||||
rows = list(csv.reader(f))
|
||||
except OSError:
|
||||
return
|
||||
if len(rows) < 2:
|
||||
return
|
||||
header = list(rows[0])
|
||||
for col_name, _ in _PIXEL_METRIC_COLUMNS:
|
||||
if col_name not in header:
|
||||
header.append(col_name)
|
||||
for ri in range(1, len(rows)):
|
||||
rows[ri].append("")
|
||||
col_ix = {name: i for i, name in enumerate(header)}
|
||||
rows[0] = header
|
||||
target_ri: int | None = None
|
||||
for ri in range(1, len(rows)):
|
||||
row = rows[ri]
|
||||
while len(row) < len(header):
|
||||
row.append("")
|
||||
try:
|
||||
if int(float(row[0].strip())) == int(epoch_1based):
|
||||
target_ri = ri
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
if target_ri is None:
|
||||
return
|
||||
row = rows[target_ri]
|
||||
while len(row) < len(header):
|
||||
row.append("")
|
||||
for col_name, sk in _PIXEL_METRIC_COLUMNS:
|
||||
row[col_ix[col_name]] = _fmt_csv_metric(stats.get(sk))
|
||||
try:
|
||||
with open(csv_path, "w", newline="", encoding="utf-8") as f:
|
||||
w = csv.writer(f)
|
||||
w.writerows(rows)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _make_pixel_metrics_callback(data_yaml: str, every: int, imgsz: int, conf: float = 0.25):
|
||||
def on_fit_epoch_end(trainer):
|
||||
from ultralytics.utils import RANK
|
||||
|
||||
if RANK not in {-1, 0}:
|
||||
return
|
||||
if every <= 0:
|
||||
return
|
||||
ep = int(getattr(trainer, "epoch", -1))
|
||||
if (ep + 1) % every != 0:
|
||||
return
|
||||
w = Path(trainer.save_dir) / "weights" / "last.pt"
|
||||
if not w.is_file():
|
||||
return
|
||||
m = YOLO(str(w))
|
||||
stats = eval_val_pixel_error(
|
||||
m,
|
||||
data_yaml,
|
||||
device=trainer.device,
|
||||
imgsz=imgsz,
|
||||
conf=conf,
|
||||
)
|
||||
mean_px = stats.get("mean_px")
|
||||
p95_px = stats.get("p95_px")
|
||||
mean_s = f"{mean_px:.3f}" if mean_px is not None else "n/a"
|
||||
p95_s = f"{p95_px:.3f}" if p95_px is not None else "n/a"
|
||||
print(
|
||||
f"\n[pixel-metrics] epoch {ep + 1}: mean_px={mean_s} p95_px={p95_s} "
|
||||
f"n_points={stats.get('n_points', 0)} "
|
||||
f"skip(det/gt/k)={stats['skip_no_det']}/{stats['skip_no_gt']}/{stats['skip_kpt_mismatch']}\n"
|
||||
)
|
||||
_merge_pixel_metrics_into_results_csv(trainer.save_dir, ep + 1, stats)
|
||||
|
||||
return on_fit_epoch_end
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="YOLO Pose 训练(XPU/CUDA/CPU)")
|
||||
ap.add_argument("--data", default="datasets/dataset_pose.yaml", help="data.yaml")
|
||||
ap.add_argument("--model", default="yolo11x-pose.pt", help="预训练权重")
|
||||
ap.add_argument("--epochs", type=int, default=100)
|
||||
ap.add_argument("--imgsz", type=int, default=960, help="训练输入边长(默认 960)")
|
||||
ap.add_argument("--batch", type=int, default=4, help="批大小;OOM 时减小")
|
||||
ap.add_argument(
|
||||
"--device",
|
||||
default="auto",
|
||||
help="auto | xpu | 0 | cuda | cpu(auto:XPU 优先)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--no-amp",
|
||||
action="store_true",
|
||||
help="关闭混合精度(默认:CUDA 开启,XPU/CPU 关闭)",
|
||||
)
|
||||
ap.add_argument("--project", default="runs/pose")
|
||||
ap.add_argument("--name", default="target_pose_train")
|
||||
ap.add_argument("--workers", type=int, default=4)
|
||||
ap.add_argument(
|
||||
"--pixel-metrics-every",
|
||||
type=int,
|
||||
default=0,
|
||||
help="每 N 个 epoch 在 val 上打印像素误差并写入 results.csv 对应 epoch 行(0=关闭);需 labels 与 data.yaml 布局一致",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--pixel-metrics-conf",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="--pixel-metrics-every 时 predict 置信度阈值(默认 0.25)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--best-by-pixel",
|
||||
action="store_true",
|
||||
help="best.pt 与 early stopping 按验证集 mean 像素误差(同 pose_pixel_metrics),fitness=-mean_px;单卡有效,DDP 自动退回 mAP",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--pixel-fitness-conf",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="--best-by-pixel 时 predict 置信度阈值(默认与 pixel-metrics 一致)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--export-onnx",
|
||||
action="store_true",
|
||||
help="训练结束后导出 ONNX(需再设 --onnx-imgsz)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--onnx-imgsz",
|
||||
type=int,
|
||||
nargs=2,
|
||||
metavar=("H", "W"),
|
||||
default=[224, 320],
|
||||
help="导出 ONNX 的 [高, 宽],默认 224 320(Maix 常用)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--clear-label-cache",
|
||||
action="store_true",
|
||||
help="启动训练前删除 data.yaml 中 path 下的 labels/*.cache(修正标注后仍报 corrupt 时用)",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
device = _pick_device(None if args.device == "auto" else args.device)
|
||||
use_amp = False if args.no_amp else _default_amp(device)
|
||||
|
||||
if isinstance(device, torch.device) and device.type == "xpu":
|
||||
print(f"✅ 使用 Intel XPU: {device}")
|
||||
elif device == 0 or device == "0":
|
||||
print(f"✅ 使用 CUDA: {torch.cuda.get_device_name(0)}")
|
||||
else:
|
||||
print("⚠️ 使用 CPU,训练会较慢")
|
||||
|
||||
if isinstance(device, torch.device) and device.type == "xpu":
|
||||
_patch_ultralytics_for_xpu()
|
||||
|
||||
data_yaml = args.data
|
||||
if not os.path.isabs(data_yaml):
|
||||
data_yaml = os.path.join(os.path.dirname(os.path.abspath(__file__)), data_yaml)
|
||||
if not os.path.exists(data_yaml):
|
||||
print(f"❌ 数据集配置不存在: {data_yaml}")
|
||||
return
|
||||
|
||||
if args.clear_label_cache:
|
||||
n_rm = _clear_ultralytics_label_caches(data_yaml)
|
||||
print(f"🗑️ 已删除标签目录缓存 {n_rm} 个(labels/*.cache),将强制重新扫描标注。")
|
||||
|
||||
print(f"📦 加载模型: {args.model}(固定 task=pose)")
|
||||
model = YOLO(args.model, task="pose")
|
||||
|
||||
if args.best_by_pixel:
|
||||
_install_best_by_pixel_validate(data_yaml, args.imgsz, args.pixel_fitness_conf)
|
||||
print(
|
||||
"📌 已启用 --best-by-pixel:best.pt / patience 按验证集 mean 像素误差(fitness=-mean_px);"
|
||||
"反向传播仍为 Ultralytics 默认 pose/box loss。"
|
||||
)
|
||||
|
||||
if args.pixel_metrics_every > 0:
|
||||
model.add_callback(
|
||||
"on_fit_epoch_end",
|
||||
_make_pixel_metrics_callback(
|
||||
data_yaml, args.pixel_metrics_every, args.imgsz, conf=args.pixel_metrics_conf
|
||||
),
|
||||
)
|
||||
|
||||
model.train(
|
||||
task="pose",
|
||||
data=data_yaml,
|
||||
epochs=args.epochs,
|
||||
imgsz=args.imgsz,
|
||||
batch=args.batch,
|
||||
name=args.name,
|
||||
project=args.project,
|
||||
exist_ok=True,
|
||||
save=True,
|
||||
save_period=5,
|
||||
device=device,
|
||||
workers=args.workers,
|
||||
lr0=0.0001,
|
||||
lrf=0.01,
|
||||
optimizer="AdamW",
|
||||
momentum=0.937,
|
||||
weight_decay=0.001,
|
||||
warmup_epochs=0,
|
||||
warmup_momentum=0.8,
|
||||
warmup_bias_lr=0.1,
|
||||
hsv_h=0.015,
|
||||
hsv_s=0.7,
|
||||
hsv_v=0.4,
|
||||
degrees=5.0,
|
||||
translate=0.0,
|
||||
scale=0.2,
|
||||
shear=0.0,
|
||||
perspective=0.0000,
|
||||
flipud=0.0,
|
||||
fliplr=0.5,
|
||||
mosaic=0.0,
|
||||
mixup=0.0,
|
||||
copy_paste=0.0,
|
||||
box=6,
|
||||
cls=0.5,
|
||||
dfl=1.5,
|
||||
pose=18.0,
|
||||
kobj=0.5,
|
||||
freeze=0,
|
||||
seed=42,
|
||||
verbose=True,
|
||||
amp=use_amp,
|
||||
patience=100,
|
||||
cos_lr=True,
|
||||
)
|
||||
|
||||
print("\n✅ 训练完成!")
|
||||
print(f"📁 best: {args.project}/{args.name}/weights/best.pt")
|
||||
print(f"📁 last: {args.project}/{args.name}/weights/last.pt")
|
||||
print("📊 仅看像素误差可运行: python pose_pixel_metrics.py --model <best.pt> --data <yaml> --imgsz", args.imgsz)
|
||||
|
||||
if args.export_onnx:
|
||||
h, w = args.onnx_imgsz
|
||||
print(f"📦 导出 ONNX imgsz=[{h}, {w}] ...")
|
||||
model.export(format="onnx", imgsz=[h, w], simplify=True, opset=17, dynamic=False)
|
||||
print("✅ ONNX 完成")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
6
triangle_positions.json
Normal file
6
triangle_positions.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"0": [-20.0, -20.0, 0.0],
|
||||
"1": [-20.0, 20.0, 0.0],
|
||||
"2": [ 20.0, 20.0, 0.0],
|
||||
"3": [ 20.0, -20.0, 0.0]
|
||||
}
|
||||
1865
triangle_target.py
Normal file
1865
triangle_target.py
Normal file
File diff suppressed because it is too large
Load Diff
30
version.py
Normal file
30
version.py
Normal file
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
应用版本号
|
||||
每次 OTA 更新时,只需要更新这个文件中的版本号
|
||||
"""
|
||||
VERSION = '1.2.12'
|
||||
|
||||
# 1.2.0 开始使用C++编译成.so,替换部分代码
|
||||
# 1.2.1 ota使用加密包
|
||||
# 1.2.2 支持wifi ota,并且设定时区,并使用单独线程保存图片
|
||||
# 1.2.3 修改ADC_TRIGGER_THRESHOLD 为2300,支持上传日志到服务器
|
||||
# 1.2.4 修改ADC_TRIGGER_THRESHOLD 为3000,并默认关闭摄像头的显示,并把ADC的采样间隔从50ms降低到10ms
|
||||
# 1.2.5 支持空气传感器采样,并默认关闭日志。优化断网时的发送队列丢消息问题,解决 WiFi 断线检测不可靠问题。
|
||||
# 1.2.6 在链接 wifi 前先判断 wifi 的可用性,假如不可用,则不落盘。增加日志批量压缩上传功能
|
||||
# 1.2.7 修复OTA失败的bug, 空气压力传感器的阈值是2500
|
||||
# 1.2.8 (1) 加快 wifi 下数据传输的速度。(2) 调整射箭时处理的逻辑,优先上报数据,再存照片之类的操作。(3)假如是用户打开激光的,射箭触发后不再关闭激光,因为是调瞄阶段
|
||||
# 1.2.9 增加电源板的控制和自动关机的功能
|
||||
# 1.2.10 config formal
|
||||
# 1.2.11 增加三角形的单应性算法,适配对应的靶纸
|
||||
# 1.2.110 关掉了黑色三角形算法,只用于测试
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
685
wifi.py
Normal file
685
wifi.py
Normal file
@@ -0,0 +1,685 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
WiFi管理模块
|
||||
提供WiFi连接、网络检测、质量监测等功能
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import threading
|
||||
import time as std_time
|
||||
from maix import time
|
||||
|
||||
import config
|
||||
from logger_manager import logger_manager
|
||||
from wpa_supplicant_conf import build_sta_conf_open, build_sta_conf_psk
|
||||
|
||||
|
||||
class WiFiManager:
|
||||
"""WiFi管理器(单例)"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(WiFiManager, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# WiFi 相关状态
|
||||
self._wifi_connected = False
|
||||
self._wifi_ip = None
|
||||
self._wifi_socket = None
|
||||
self._wifi_socket_lock = threading.Lock()
|
||||
self._prefer_wifi = True # 是否优先使用 WiFi
|
||||
self._recv_buffer = b"" # TCP 接收缓冲区
|
||||
|
||||
# WiFi 质量监测(后台线程)
|
||||
self._wifi_quality_monitor_thread = None
|
||||
self._wifi_quality_stop_event = threading.Event()
|
||||
self._last_wifi_rtt_ms = None # 最近一次测量的 RTT
|
||||
self._last_wifi_rssi_dbm = None # 最近一次测量的 RSSI
|
||||
|
||||
# 服务器相关(用于网络检测)
|
||||
try:
|
||||
import archery_netcore as _netcore
|
||||
self._server_ip = _netcore.get_config().get("SERVER_IP")
|
||||
self._server_port = _netcore.get_config().get("SERVER_PORT")
|
||||
except Exception:
|
||||
self._server_ip = getattr(config, "SERVER_IP", None)
|
||||
self._server_port = getattr(config, "SERVER_PORT", None)
|
||||
|
||||
self._initialized = True
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
"""获取 logger 对象"""
|
||||
return logger_manager.logger
|
||||
|
||||
@property
|
||||
def wifi_connected(self):
|
||||
"""WiFi是否已连接"""
|
||||
return self._wifi_connected
|
||||
|
||||
@property
|
||||
def wifi_ip(self):
|
||||
"""WiFi IP地址"""
|
||||
return self._wifi_ip
|
||||
|
||||
@property
|
||||
def wifi_socket(self):
|
||||
"""WiFi socket对象"""
|
||||
return self._wifi_socket
|
||||
|
||||
@wifi_socket.setter
|
||||
def wifi_socket(self, value):
|
||||
"""设置WiFi socket对象"""
|
||||
self._wifi_socket = value
|
||||
|
||||
@property
|
||||
def wifi_socket_lock(self):
|
||||
"""获取WiFi socket锁"""
|
||||
return self._wifi_socket_lock
|
||||
|
||||
@property
|
||||
def prefer_wifi(self):
|
||||
"""是否优先使用WiFi"""
|
||||
return self._prefer_wifi
|
||||
|
||||
@prefer_wifi.setter
|
||||
def prefer_wifi(self, value):
|
||||
"""设置是否优先使用WiFi"""
|
||||
self._prefer_wifi = value
|
||||
|
||||
@property
|
||||
def last_wifi_rtt_ms(self):
|
||||
"""最近一次测量的RTT"""
|
||||
return self._last_wifi_rtt_ms
|
||||
|
||||
@property
|
||||
def last_wifi_rssi_dbm(self):
|
||||
"""最近一次测量的RSSI"""
|
||||
return self._last_wifi_rssi_dbm
|
||||
|
||||
@property
|
||||
def recv_buffer(self):
|
||||
"""TCP接收缓冲区"""
|
||||
return self._recv_buffer
|
||||
|
||||
@recv_buffer.setter
|
||||
def recv_buffer(self, value):
|
||||
"""设置TCP接收缓冲区"""
|
||||
self._recv_buffer = value
|
||||
|
||||
# ==================== WiFi 连接方法 ====================
|
||||
|
||||
def is_sta_associated(self):
|
||||
"""
|
||||
是否作为 STA 已关联到上游 AP(用于与 AP 模式区分:AP 模式下 wlan0 可能有 IP 但 iw link 为 Not connected)。
|
||||
"""
|
||||
try:
|
||||
out = os.popen("iw dev wlan0 link 2>/dev/null").read()
|
||||
if not out.strip():
|
||||
return False
|
||||
if "Not connected" in out:
|
||||
return False
|
||||
return "Connected to" in out
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def is_wifi_connected(self):
|
||||
"""检查WiFi是否已连接"""
|
||||
# AP 模式下 wlan0 也可能有 IP(如 192.168.66.1),但这不代表已作为 STA 连上路由器。
|
||||
# 业务侧(选网/TCP)只应在 STA 已关联到上游 AP 时认为 WiFi 可用。
|
||||
if not self.is_sta_associated():
|
||||
self._wifi_connected = False
|
||||
return False
|
||||
|
||||
# 优先用 MaixPy network(如果可用)
|
||||
try:
|
||||
from maix import network
|
||||
wifi = network.wifi.Wifi()
|
||||
if wifi.is_connected():
|
||||
self._wifi_connected = True
|
||||
# MaixPy 的 is_connected 可能不会同步填充 IP,这里用系统命令补齐一次
|
||||
try:
|
||||
ip = os.popen("ifconfig wlan0 2>/dev/null | grep 'inet ' | awk '{print $2}'").read().strip()
|
||||
if ip:
|
||||
self._wifi_ip = ip
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
except:
|
||||
self.logger.warning("Failed to check WiFi connection using MaixPy network", exc_info=True)
|
||||
|
||||
# 兜底:看系统 wlan0 有没有 IP
|
||||
try:
|
||||
ip = os.popen("ifconfig wlan0 2>/dev/null | grep 'inet ' | awk '{print $2}'").read().strip()
|
||||
if ip:
|
||||
self._wifi_connected = True
|
||||
self._wifi_ip = ip
|
||||
return True
|
||||
except:
|
||||
self.logger.warning("Failed to check WiFi connection using system command", exc_info=True)
|
||||
|
||||
self._wifi_connected = False
|
||||
return False
|
||||
|
||||
def connect_wifi(self, ssid, password, verify_callback=None, persist=True, timeout_s=20):
|
||||
"""
|
||||
连接 Wi-Fi(唯一实现:写 wpa_supplicant + /boot 凭证,MaixPy Wifi.connect,再等 IP 与可选校验)。
|
||||
|
||||
``NetworkManager.connect_wifi`` 仅封装本方法(通过 ``verify_callback`` 传入 host/port 校验)。
|
||||
|
||||
重要:``/boot/wpa_supplicant.conf`` 存在时 S30wifi 会优先 cp,避免 shell 传中文 SSID。
|
||||
|
||||
Args:
|
||||
ssid: WiFi SSID
|
||||
password: WiFi密码
|
||||
verify_callback: 可选;``(ip) -> (success: bool, error: str)``,在拿到 IP 后调用
|
||||
persist: 是否持久化保存凭证(False 时成功后回滚 /boot 与 /etc 中的本次写入)
|
||||
timeout_s: 等待 DHCP / 轮询 IP 的超时基数(秒);Maix 连接超时亦据此推导
|
||||
|
||||
Returns:
|
||||
(ip, error): IP地址和错误信息(成功时 error 为 None)
|
||||
"""
|
||||
# 配置文件路径定义
|
||||
conf_path = "/etc/wpa_supplicant.conf"
|
||||
boot_wpa_path = "/boot/wpa_supplicant.conf"
|
||||
ssid_file = "/boot/wifi.ssid"
|
||||
pass_file = "/boot/wifi.pass"
|
||||
|
||||
def _read_text(path: str):
|
||||
try:
|
||||
if os.path.exists(path):
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _write_text(path: str, content: str):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
def _restore_boot(old_ssid: str | None, old_pass: str | None):
|
||||
# 还原 /boot 凭证:原来没有就删除,原来有就写回
|
||||
try:
|
||||
if old_ssid is None:
|
||||
if os.path.exists(ssid_file):
|
||||
os.remove(ssid_file)
|
||||
else:
|
||||
_write_text(ssid_file, old_ssid)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if old_pass is None:
|
||||
if os.path.exists(pass_file):
|
||||
os.remove(pass_file)
|
||||
else:
|
||||
_write_text(pass_file, old_pass)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _restore_boot_wpa(old_wpa: str | None):
|
||||
try:
|
||||
if old_wpa is None:
|
||||
if os.path.exists(boot_wpa_path):
|
||||
os.remove(boot_wpa_path)
|
||||
else:
|
||||
_write_text(boot_wpa_path, old_wpa)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
old_conf = _read_text(conf_path)
|
||||
old_boot_ssid = _read_text(ssid_file)
|
||||
old_boot_pass = _read_text(pass_file)
|
||||
old_boot_wpa = _read_text(boot_wpa_path) if os.path.exists(boot_wpa_path) else None
|
||||
|
||||
try:
|
||||
try:
|
||||
full_conf = build_sta_conf_psk(ssid.strip(), password.strip())
|
||||
except ValueError as ve:
|
||||
raise RuntimeError(str(ve)) from ve
|
||||
|
||||
try:
|
||||
_write_text(conf_path, full_conf)
|
||||
except Exception:
|
||||
pass
|
||||
_write_text(boot_wpa_path, full_conf)
|
||||
|
||||
# 仍写入 ssid/pass,便于其它脚本/人工查看;S30wifi 优先使用 wpa_supplicant.conf
|
||||
_write_text(ssid_file, ssid.strip())
|
||||
_write_text(pass_file, password.strip())
|
||||
|
||||
from maix import err as maix_err
|
||||
from maix import network as maix_net
|
||||
|
||||
self.logger.info(f"[WIFI] Maix connect start ssid={ssid!r}")
|
||||
w = maix_net.wifi.Wifi()
|
||||
connect_timeout_s = int(timeout_s) if timeout_s and timeout_s > 0 else 60
|
||||
connect_timeout_s = max(10, min(connect_timeout_s, 120))
|
||||
e = w.connect(ssid, password, wait=True, timeout=connect_timeout_s)
|
||||
maix_err.check_raise(e, "connect wifi failed")
|
||||
try:
|
||||
maix_ip = w.get_ip()
|
||||
except Exception:
|
||||
maix_ip = None
|
||||
self.logger.info(f"[WIFI] Maix connect ok ip={maix_ip!r}")
|
||||
|
||||
# 等待获取 IP
|
||||
wait_s = int(timeout_s) if timeout_s and timeout_s > 0 else 20
|
||||
wait_s = min(max(wait_s, 5), 60)
|
||||
for _ in range(wait_s):
|
||||
ip = os.popen("ifconfig wlan0 2>/dev/null | grep 'inet ' | awk '{print $2}'").read().strip()
|
||||
if ip:
|
||||
# 拿到 IP 不代表可上网/可访问目标;继续做可达性验证
|
||||
self._wifi_connected = True
|
||||
self._wifi_ip = ip
|
||||
self.logger.info(f"[WIFI] 已连接,IP: {ip},开始验证网络可用性...")
|
||||
|
||||
# 验证能访问指定目标(通过回调函数)
|
||||
if verify_callback:
|
||||
success, error = verify_callback(ip)
|
||||
if not success:
|
||||
raise RuntimeError(error or "Verification failed")
|
||||
|
||||
# ====== 验证通过 ======
|
||||
if not persist:
|
||||
# 不持久化:把 /boot 恢复成旧值(不重启,当前连接保持不变)
|
||||
_restore_boot(old_boot_ssid, old_boot_pass)
|
||||
_restore_boot_wpa(old_boot_wpa)
|
||||
self.logger.info("[WIFI] 网络验证通过,但按 persist=False 回滚 /boot 凭证(不重启)")
|
||||
else:
|
||||
self.logger.info("[WIFI] 网络验证通过,/boot 凭证已保留(持久化)")
|
||||
|
||||
return ip, None
|
||||
|
||||
std_time.sleep(1)
|
||||
|
||||
raise RuntimeError("Timeout: No IP obtained")
|
||||
|
||||
except Exception as e:
|
||||
# 失败:回滚 /boot 和 /etc,重启 WiFi 恢复旧网络
|
||||
_restore_boot(old_boot_ssid, old_boot_pass)
|
||||
_restore_boot_wpa(old_boot_wpa)
|
||||
try:
|
||||
if old_conf is not None:
|
||||
_write_text(conf_path, old_conf)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
os.system("/etc/init.d/S30wifi restart")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._wifi_connected = False
|
||||
self._wifi_ip = None
|
||||
self.logger.error(f"[WIFI] 连接/验证失败,已回滚: {e}")
|
||||
return None, str(e)
|
||||
|
||||
def persist_sta_credentials(self, ssid: str, password: str, restart_service: bool = True):
|
||||
"""
|
||||
仅写入 STA 凭证(/etc/wpa_supplicant.conf、/boot/wpa_supplicant.conf、/boot/wifi.ssid|pass),
|
||||
可选是否立即 /etc/init.d/S30wifi restart。
|
||||
不做可达性验证。用于热点配网页提交后切换到连接指定路由器。
|
||||
password 为空时按开放网络(key_mgmt=NONE)写入。
|
||||
Returns:
|
||||
(ok: bool, err_msg: str)
|
||||
"""
|
||||
ssid = (ssid or "").strip()
|
||||
password = (password or "").strip()
|
||||
if not ssid:
|
||||
return False, "SSID 为空"
|
||||
|
||||
conf_path = "/etc/wpa_supplicant.conf"
|
||||
boot_wpa_path = "/boot/wpa_supplicant.conf"
|
||||
ssid_file = "/boot/wifi.ssid"
|
||||
pass_file = "/boot/wifi.pass"
|
||||
|
||||
def _write_text(path: str, content: str):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
try:
|
||||
if password:
|
||||
full_conf = build_sta_conf_psk(ssid, password)
|
||||
else:
|
||||
full_conf = build_sta_conf_open(ssid)
|
||||
_write_text(conf_path, full_conf)
|
||||
_write_text(boot_wpa_path, full_conf)
|
||||
except ValueError as e:
|
||||
return False, str(e)
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
try:
|
||||
_write_text(ssid_file, ssid)
|
||||
_write_text(pass_file, password)
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
if restart_service:
|
||||
try:
|
||||
os.system("/etc/init.d/S30wifi restart")
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
self.logger.info(f"[WIFI] persist_sta_credentials: 已写入并重启 S30wifi, ssid={ssid!r}")
|
||||
else:
|
||||
self.logger.info(f"[WIFI] persist_sta_credentials: 已写入凭证(未重启 S30wifi), ssid={ssid!r}")
|
||||
return True, ""
|
||||
|
||||
def disconnect_wifi(self):
|
||||
"""断开WiFi连接并清理资源"""
|
||||
if self._wifi_socket:
|
||||
try:
|
||||
self._wifi_socket.close()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._wifi_socket = None
|
||||
self._wifi_connected = False
|
||||
self._wifi_ip = None
|
||||
|
||||
# ==================== WiFi 质量监测 ====================
|
||||
|
||||
def _get_wifi_rssi_dbm(self):
|
||||
"""
|
||||
获取 WiFi 信号强度(dBm,越大越好;比如 -40 比 -80 好)
|
||||
由于不同固件实现差异,这里做多策略兜底,失败返回 None
|
||||
"""
|
||||
# 1) 优先使用:iw dev wlan0 link
|
||||
# 你提供的输出示例包含:signal: -58 dBm
|
||||
try:
|
||||
out = os.popen("iw dev wlan0 link 2>/dev/null").read()
|
||||
if out:
|
||||
m = re.search(r"signal:\s*(-?\d+(?:\.\d+)?)\s*dBm", out, re.IGNORECASE)
|
||||
if m:
|
||||
v = float(m.group(1))
|
||||
# 合理范围兜底
|
||||
if -120.0 <= v <= 0.0:
|
||||
return v
|
||||
m2 = re.search(r"signal:\s*(-?\d+(?:\.\d+)?)", out, re.IGNORECASE)
|
||||
if m2:
|
||||
v = float(m2.group(1))
|
||||
if -120.0 <= v <= 0.0:
|
||||
return v
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2) 兜底:iwconfig
|
||||
try:
|
||||
out = os.popen("iwconfig wlan0 2>/dev/null").read()
|
||||
m = re.search(r"Signal level[=:]\s*(-?\d+(?:\.\d+)?)\s*dBm", out, re.IGNORECASE)
|
||||
if m:
|
||||
v = float(m.group(1))
|
||||
if -120.0 <= v <= 0.0:
|
||||
return v
|
||||
m2 = re.search(r"Signal level[=:]\s*(-?\d+(?:\.\d+)?)", out, re.IGNORECASE)
|
||||
if m2:
|
||||
v = float(m2.group(1))
|
||||
if -120.0 <= v <= 0.0:
|
||||
return v
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _measure_wifi_tcp_rtt_ms(self, host, port, samples=3, per_sample_timeout_ms=900):
|
||||
"""
|
||||
测量:在当前 WiFi 下,TCP 建连耗时(RTT 的近似)
|
||||
|
||||
Args:
|
||||
host: 目标主机
|
||||
port: 目标端口
|
||||
samples: 采样次数
|
||||
per_sample_timeout_ms: 每次采样超时时间(毫秒)
|
||||
|
||||
Returns:
|
||||
(median_rtt_ms, reachable_bool)
|
||||
"""
|
||||
rtts = []
|
||||
reachable = False
|
||||
addr = None
|
||||
|
||||
# 先解析一次地址,避免每次样本都做 DNS
|
||||
try:
|
||||
addr_info = socket.getaddrinfo(host, port)[0]
|
||||
addr = (addr_info[0], addr_info[1], addr_info[2], addr_info[-1])
|
||||
except Exception:
|
||||
return float("inf"), False
|
||||
|
||||
for _ in range(max(1, int(samples or 1))):
|
||||
s = None
|
||||
try:
|
||||
s = socket.socket(addr[0], addr[1], addr[2])
|
||||
s.settimeout(max(0.1, float(per_sample_timeout_ms) / 1000.0))
|
||||
t0 = time.ticks_ms()
|
||||
s.connect(addr[-1])
|
||||
elapsed_ms = abs(time.ticks_diff(time.ticks_ms(), t0))
|
||||
rtts.append(float(elapsed_ms))
|
||||
reachable = True
|
||||
except Exception:
|
||||
# 单个样本失败不影响整体,只要有成功样本就继续
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
if s:
|
||||
s.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 小间隔,避免过度占用
|
||||
try:
|
||||
time.sleep_ms(100)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not rtts:
|
||||
return float("inf"), False
|
||||
|
||||
rtts_sorted = sorted(rtts)
|
||||
mid = len(rtts_sorted) // 2
|
||||
if len(rtts_sorted) % 2 == 1:
|
||||
median = rtts_sorted[mid]
|
||||
else:
|
||||
median = (rtts_sorted[mid - 1] + rtts_sorted[mid]) / 2.0
|
||||
return median, reachable
|
||||
|
||||
def _is_wifi_quality_bad(self, wifi_rtt_ms, wifi_rssi_dbm):
|
||||
"""
|
||||
综合判断 WiFi 质量是否差:
|
||||
- RTT中位数超过阈值 -> bad
|
||||
- 若启用 RSSI:信号弱(RSSI更差于阈值) 且 RTT 也偏高 -> bad
|
||||
"""
|
||||
if wifi_rtt_ms >= config.WIFI_QUALITY_RTT_BAD_MS:
|
||||
return True
|
||||
|
||||
if not getattr(config, "WIFI_QUALITY_USE_RSSI", False):
|
||||
return False
|
||||
|
||||
if wifi_rssi_dbm is None:
|
||||
return False
|
||||
|
||||
# "rtt_warn + rssi_bad" 联合条件
|
||||
if wifi_rtt_ms >= config.WIFI_QUALITY_RTT_WARN_MS and wifi_rssi_dbm <= config.WIFI_QUALITY_RSSI_BAD_DBM:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_wifi_quality_status(self):
|
||||
"""
|
||||
获取当前 WiFi 质量状态(用于调试或显示)
|
||||
|
||||
Returns:
|
||||
dict: {"rtt_ms": float, "rssi_dbm": float, "is_bad": bool}
|
||||
"""
|
||||
rtt = self._last_wifi_rtt_ms
|
||||
rssi = self._last_wifi_rssi_dbm
|
||||
is_bad = False
|
||||
|
||||
if rtt is not None and rtt != float("inf"):
|
||||
is_bad = self._is_wifi_quality_bad(rtt, rssi)
|
||||
|
||||
return {
|
||||
"rtt_ms": rtt if rtt is not None and rtt != float("inf") else None,
|
||||
"rssi_dbm": rssi,
|
||||
"is_bad": is_bad
|
||||
}
|
||||
|
||||
# ==================== 后台质量监测线程 ====================
|
||||
|
||||
def start_quality_monitor(self, network_type_callback, on_poor_quality_callback):
|
||||
"""
|
||||
启动 WiFi 质量后台监测线程(每 5 秒测量一次 RTT 和 RSSI)
|
||||
只在 WiFi 连接时运行,不影响业务发送性能
|
||||
|
||||
Args:
|
||||
network_type_callback: 获取当前网络类型的回调函数
|
||||
on_poor_quality_callback: WiFi质量差时的回调函数
|
||||
"""
|
||||
if self._wifi_quality_monitor_thread is not None:
|
||||
self.logger.warning("[WiFi Monitor] 监测线程已在运行")
|
||||
return
|
||||
|
||||
self._network_type_callback = network_type_callback
|
||||
self._on_poor_quality_callback = on_poor_quality_callback
|
||||
self._wifi_quality_stop_event.clear()
|
||||
self._wifi_quality_monitor_thread = threading.Thread(
|
||||
target=self._quality_monitor_loop,
|
||||
daemon=True,
|
||||
name="wifi_quality_monitor"
|
||||
)
|
||||
self._wifi_quality_monitor_thread.start()
|
||||
self.logger.info("[WiFi Monitor] 已启动后台监测线程")
|
||||
|
||||
def stop_quality_monitor(self):
|
||||
"""停止 WiFi 质量监测线程"""
|
||||
if self._wifi_quality_monitor_thread is None:
|
||||
return
|
||||
|
||||
self._wifi_quality_stop_event.set()
|
||||
try:
|
||||
self._wifi_quality_monitor_thread.join(timeout=2.0)
|
||||
except Exception as e:
|
||||
self.logger.error(f"[WiFi Monitor] 停止线程失败:{e}")
|
||||
finally:
|
||||
self._wifi_quality_monitor_thread = None
|
||||
self.logger.info("[WiFi Monitor] 已停止后台监测线程")
|
||||
|
||||
def _quality_monitor_loop(self):
|
||||
"""
|
||||
WiFi 质量监测循环(后台线程)
|
||||
每 5 秒测量一次 RTT 和 RSSI,发现质量差则触发切换
|
||||
"""
|
||||
while not self._wifi_quality_stop_event.is_set():
|
||||
try:
|
||||
# 只在 WiFi 连接时才测量
|
||||
network_type = self._network_type_callback()
|
||||
if network_type == "wifi" and self._wifi_socket:
|
||||
# # 测量 RTT(1 个样本,快速测量)
|
||||
# rtt_ms, reachable = self._measure_wifi_tcp_rtt_ms(
|
||||
# self._server_ip, self._server_port,
|
||||
# samples=1, per_sample_timeout_ms=600
|
||||
# )
|
||||
|
||||
# 获取 RSSI
|
||||
rssi_dbm = self._get_wifi_rssi_dbm()
|
||||
|
||||
# 更新缓存
|
||||
# 不使用 RTT 测量
|
||||
rtt_ms = 0
|
||||
reachable = True
|
||||
self._last_wifi_rtt_ms = rtt_ms if reachable else None
|
||||
self._last_wifi_rssi_dbm = rssi_dbm
|
||||
_rssi_s = f"{rssi_dbm:.0f}" if rssi_dbm is not None else "n/a"
|
||||
self.logger.debug(f"[WiFi Monitor] - RTT={rtt_ms:.0f}ms, RSSI={_rssi_s}dBm")
|
||||
|
||||
# 判断质量是否差(切换前做 2 次快速复测,防止瞬时抖动)
|
||||
def _is_bad_now(_reachable, _rtt, _rssi):
|
||||
if (not _reachable) or (_rtt is None) or (_rtt == float("inf")):
|
||||
return True
|
||||
return self._is_wifi_quality_bad(_rtt, _rssi)
|
||||
|
||||
bad = _is_bad_now(reachable, rtt_ms, rssi_dbm)
|
||||
if bad:
|
||||
self.logger.warning("[WiFi Monitor] 质量差,切换前快速重试 2 次(每次间隔1秒)")
|
||||
|
||||
for retry_idx in range(2):
|
||||
time.sleep_ms(1000)
|
||||
# 不使用 RTT 测量
|
||||
rtt2 = 0
|
||||
reachable2 = True
|
||||
# rtt2, reachable2 = self._measure_wifi_tcp_rtt_ms(
|
||||
# self._server_ip, self._server_port,
|
||||
# samples=1, per_sample_timeout_ms=600
|
||||
# )
|
||||
rssi2 = self._get_wifi_rssi_dbm()
|
||||
|
||||
# 更新缓存,便于外部查看最新状态
|
||||
self._last_wifi_rtt_ms = rtt2 if reachable2 else None
|
||||
self._last_wifi_rssi_dbm = rssi2
|
||||
|
||||
bad2 = _is_bad_now(reachable2, rtt2, rssi2)
|
||||
try:
|
||||
_rtt_disp = (
|
||||
rtt2
|
||||
if rtt2 is not None and rtt2 != float("inf")
|
||||
else -1
|
||||
)
|
||||
self.logger.info(
|
||||
f"[WiFi Monitor] 复测{retry_idx+1}/2: reachable={reachable2}, "
|
||||
f"rtt={_rtt_disp:.0f}ms, rssi={rssi2}, bad={bad2}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not bad2:
|
||||
self.logger.info("[WiFi Monitor] 复测恢复正常,继续保留 WiFi(不切换)")
|
||||
bad = False
|
||||
break
|
||||
|
||||
if bad:
|
||||
self.logger.warning("[WiFi Monitor] 复测仍差/不通,尝试切换到 4G")
|
||||
self._on_poor_quality_callback()
|
||||
|
||||
# 休眠 5 秒
|
||||
time.sleep(5)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"[WiFi Monitor] 监测异常:{e}")
|
||||
# 异常后继续循环,避免线程退出
|
||||
continue
|
||||
|
||||
|
||||
# 全局 WiFi 管理器实例
|
||||
wifi_manager = WiFiManager()
|
||||
|
||||
|
||||
# ==================== 兼容旧接口的函数 ====================
|
||||
|
||||
def is_wifi_connected():
|
||||
"""尽量判断当前是否有 Wi-Fi(有则走 Wi-Fi OTA,否则走 4G OTA)"""
|
||||
return wifi_manager.is_wifi_connected()
|
||||
|
||||
|
||||
def connect_wifi(ssid, password, verify_callback=None, persist=True, timeout_s=20):
|
||||
"""
|
||||
连接 Wi-Fi 并将凭证持久化保存到 /boot/ 目录,
|
||||
以便设备重启后自动连接。
|
||||
|
||||
Args:
|
||||
ssid: WiFi SSID
|
||||
password: WiFi密码
|
||||
verify_callback: 验证回调函数,接收 (ip) 参数,返回 (success: bool, error: str)
|
||||
persist: 是否持久化保存
|
||||
timeout_s: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
(ip, error): IP地址和错误信息(成功时error为None)
|
||||
"""
|
||||
return wifi_manager.connect_wifi(ssid, password, verify_callback, persist, timeout_s)
|
||||
521
wifi_config_httpd.py
Normal file
521
wifi_config_httpd.py
Normal file
@@ -0,0 +1,521 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
WiFi 热点配网:迷你 HTTP 服务器(仅 GET/POST,标准库 socket),独立线程运行。
|
||||
|
||||
策略(与 /etc/init.d/S30wifi 一致):
|
||||
- 仅当 STA 未连上 WiFi 且 4G 也不可用时,写入 /boot/wifi.ap、去掉 /boot/wifi.sta,
|
||||
并重启 S30wifi 由系统起热点;再在本进程起 HTTP。
|
||||
- 用户 POST 提交路由器 SSID/密码后:仅写凭证、stop S30wifi、删 /boot/wifi.ap、建 /boot/wifi.sta、sync、reboot。
|
||||
"""
|
||||
import html
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import time as std_time
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import config
|
||||
from logger_manager import logger_manager
|
||||
from wifi import wifi_manager
|
||||
|
||||
|
||||
_http_thread = None
|
||||
_http_stop = threading.Event()
|
||||
|
||||
|
||||
def _http_response(status, body_bytes, content_type="text/html; charset=utf-8"):
|
||||
head = (
|
||||
f"HTTP/1.1 {status}\r\n"
|
||||
f"Content-Type: {content_type}\r\n"
|
||||
f"Content-Length: {len(body_bytes)}\r\n"
|
||||
f"Connection: close\r\n"
|
||||
f"\r\n"
|
||||
).encode("utf-8")
|
||||
return head + body_bytes
|
||||
|
||||
|
||||
def _read_http_request(conn, max_total=65536):
|
||||
"""返回 (method, path, headers_str, body_bytes) 或 None。"""
|
||||
buf = b""
|
||||
while b"\r\n\r\n" not in buf and len(buf) < max_total:
|
||||
chunk = conn.recv(4096)
|
||||
if not chunk:
|
||||
break
|
||||
buf += chunk
|
||||
if b"\r\n\r\n" not in buf:
|
||||
return None
|
||||
idx = buf.index(b"\r\n\r\n")
|
||||
header_bytes = buf[:idx]
|
||||
rest = buf[idx + 4 :]
|
||||
try:
|
||||
headers_str = header_bytes.decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
headers_str = ""
|
||||
lines = headers_str.split("\r\n")
|
||||
if not lines:
|
||||
return None
|
||||
parts = lines[0].split()
|
||||
method = parts[0] if parts else "GET"
|
||||
path = parts[1] if len(parts) > 1 else "/"
|
||||
|
||||
content_length = 0
|
||||
for line in lines[1:]:
|
||||
if line.lower().startswith("content-length:"):
|
||||
try:
|
||||
content_length = int(line.split(":", 1)[1].strip())
|
||||
except Exception:
|
||||
content_length = 0
|
||||
break
|
||||
|
||||
body = rest
|
||||
while content_length > 0 and len(body) < content_length and len(body) < max_total:
|
||||
chunk = conn.recv(4096)
|
||||
if not chunk:
|
||||
break
|
||||
body += chunk
|
||||
body = body[:content_length]
|
||||
return method, path, headers_str, body
|
||||
|
||||
|
||||
def _page_form(msg_html=""):
|
||||
# 页面展示的热点名:以 /boot/wifi.ssid 为准(与实际 AP 保持一致)
|
||||
try:
|
||||
if os.path.exists("/boot/wifi.ssid"):
|
||||
with open("/boot/wifi.ssid", "r", encoding="utf-8") as f:
|
||||
_ssid = f.read().strip()
|
||||
else:
|
||||
_ssid = ""
|
||||
except Exception:
|
||||
_ssid = ""
|
||||
ap_ssid = html.escape(_ssid or getattr(config, "WIFI_CONFIG_AP_SSID", "ArcherySetup"))
|
||||
port = int(getattr(config, "WIFI_CONFIG_HTTP_PORT", 8080))
|
||||
ap_ip = html.escape(getattr(config, "WIFI_CONFIG_AP_IP", "192.168.66.1"))
|
||||
body = f"""<!DOCTYPE html>
|
||||
<html><head><meta charset="utf-8"/><meta name="viewport" content="width=device-width,initial-scale=1"/>
|
||||
<title>WiFi 配网</title></head><body>
|
||||
<h1>WiFi 配网</h1>
|
||||
<p>热点:<b>{ap_ssid}</b> · 端口 <b>{port}</b></p>
|
||||
<p>请填写要连接的<b>路由器</b> SSID 与密码(用于 STA 上网,不是热点密码)。提交后将关闭热点、保存并<b>重启设备</b>。</p>
|
||||
{msg_html}
|
||||
<form method="POST" action="/" accept-charset="utf-8">
|
||||
<p>SSID<br/><input name="ssid" type="text" style="width:100%;max-width:320px" required/></p>
|
||||
<p>密码(开放网络可留空)<br/><input name="password" type="password" style="width:100%;max-width:320px"/></p>
|
||||
<p><button type="submit">保存并重启</button></p>
|
||||
</form>
|
||||
<p style="color:#666;font-size:12px">提示:提交后设备会重启;请手机改连路由器 WiFi。</p>
|
||||
</body></html>"""
|
||||
return body.encode("utf-8")
|
||||
|
||||
|
||||
def _apply_sta_and_reboot(router_ssid: str, router_password: str):
|
||||
"""
|
||||
写路由器 STA 凭证 -> 停 WiFi 服务 -> 删 /boot/wifi.ap -> 建 /boot/wifi.sta -> sync -> reboot
|
||||
"""
|
||||
logger = logger_manager.logger
|
||||
ok, err = wifi_manager.persist_sta_credentials(router_ssid, router_password, restart_service=False)
|
||||
if not ok:
|
||||
return False, err
|
||||
|
||||
try:
|
||||
os.system("/etc/init.d/S30wifi stop")
|
||||
except Exception as e:
|
||||
logger.warning(f"[WIFI-AP] S30wifi stop: {e}")
|
||||
|
||||
ap_flag = "/boot/wifi.ap"
|
||||
sta_flag = "/boot/wifi.sta"
|
||||
try:
|
||||
if os.path.exists(ap_flag):
|
||||
os.remove(ap_flag)
|
||||
except Exception as e:
|
||||
return False, f"删除 {ap_flag} 失败: {e}"
|
||||
|
||||
try:
|
||||
with open(sta_flag, "w", encoding="utf-8") as f:
|
||||
f.write("")
|
||||
except Exception as e:
|
||||
return False, f"创建 {sta_flag} 失败: {e}"
|
||||
|
||||
try:
|
||||
os.system("sync")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("[WIFI-AP] 已切换为 STA 标志并准备 reboot")
|
||||
try:
|
||||
os.system("reboot")
|
||||
except Exception as e:
|
||||
return False, f"reboot 调用失败: {e}"
|
||||
return True, ""
|
||||
|
||||
|
||||
def _handle_client(conn, addr):
|
||||
logger = logger_manager.logger
|
||||
try:
|
||||
conn.settimeout(30.0)
|
||||
req = _read_http_request(conn)
|
||||
if not req:
|
||||
conn.sendall(_http_response("400 Bad Request", b"Bad Request"))
|
||||
return
|
||||
method, path, _headers, body = req
|
||||
path = path.split("?", 1)[0]
|
||||
|
||||
if method == "GET" and path in ("/", "/index.html"):
|
||||
conn.sendall(_http_response("200 OK", _page_form()))
|
||||
return
|
||||
|
||||
if method == "POST" and path in ("/", "/index.html"):
|
||||
try:
|
||||
qs = body.decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
qs = ""
|
||||
fields = parse_qs(qs, keep_blank_values=True)
|
||||
ssid = (fields.get("ssid") or [""])[0].strip()
|
||||
password = (fields.get("password") or [""])[0]
|
||||
ok, err = _apply_sta_and_reboot(ssid, password)
|
||||
if ok:
|
||||
msg = '<p style="color:green"><b>已保存,设备正在重启…</b></p>'
|
||||
else:
|
||||
msg = f'<p style="color:red"><b>失败:</b>{html.escape(err)}</p>'
|
||||
conn.sendall(_http_response("200 OK", _page_form(msg)))
|
||||
return
|
||||
|
||||
if method == "GET" and path == "/favicon.ico":
|
||||
conn.sendall(_http_response("204 No Content", b""))
|
||||
return
|
||||
|
||||
conn.sendall(_http_response("404 Not Found", b"Not Found"))
|
||||
except Exception as e:
|
||||
try:
|
||||
logger.error(f"[WIFI-HTTP] 处理请求异常 {addr}: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _serve_loop(host, port):
|
||||
logger = logger_manager.logger
|
||||
srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
try:
|
||||
srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
srv.bind((host, port))
|
||||
srv.listen(5)
|
||||
srv.settimeout(1.0)
|
||||
logger.info(f"[WIFI-HTTP] 监听 {host}:{port}")
|
||||
except Exception as e:
|
||||
logger.error(f"[WIFI-HTTP] bind 失败: {e}")
|
||||
try:
|
||||
srv.close()
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
while not _http_stop.is_set():
|
||||
try:
|
||||
conn, addr = srv.accept()
|
||||
except socket.timeout:
|
||||
continue
|
||||
except Exception as e:
|
||||
if _http_stop.is_set():
|
||||
break
|
||||
logger.warning(f"[WIFI-HTTP] accept: {e}")
|
||||
continue
|
||||
t = threading.Thread(target=_handle_client, args=(conn, addr), daemon=True)
|
||||
t.start()
|
||||
|
||||
try:
|
||||
srv.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[WIFI-HTTP] 服务已停止")
|
||||
|
||||
|
||||
def _ensure_hostapd_ssid(ssid: str, logger=None) -> bool:
|
||||
"""
|
||||
某些固件会把 SSID 写到 /etc/hostapd.conf 或 /boot/hostapd.conf。
|
||||
为避免只改 /boot/wifi.ssid 不生效,这里同步更新已存在的 hostapd.conf。
|
||||
Returns:
|
||||
bool: 任一文件被修改则 True
|
||||
"""
|
||||
if logger is None:
|
||||
logger = logger_manager.logger
|
||||
if not ssid:
|
||||
return False
|
||||
|
||||
changed_any = False
|
||||
for conf_path in ("/etc/hostapd.conf", "/boot/hostapd.conf"):
|
||||
try:
|
||||
if not os.path.exists(conf_path):
|
||||
continue
|
||||
with open(conf_path, "r", encoding="utf-8") as f:
|
||||
lines = f.read().splitlines()
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
changed = False
|
||||
out = []
|
||||
seen = False
|
||||
for ln in lines:
|
||||
s = ln.strip()
|
||||
if s.lower().startswith("ssid="):
|
||||
seen = True
|
||||
cur = s.split("=", 1)[1].strip()
|
||||
if cur != ssid:
|
||||
out.append(f"ssid={ssid}")
|
||||
changed = True
|
||||
else:
|
||||
out.append(ln)
|
||||
else:
|
||||
out.append(ln)
|
||||
if not seen:
|
||||
out.append(f"ssid={ssid}")
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
try:
|
||||
with open(conf_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(out).rstrip() + "\n")
|
||||
changed_any = True
|
||||
except Exception as e:
|
||||
if logger:
|
||||
logger.warning(f"[WIFI-AP] 写入 {conf_path} 失败: {e}")
|
||||
|
||||
if changed_any and logger:
|
||||
logger.info(f"[WIFI-AP] 已同步热点 SSID 到 hostapd.conf: {ssid}")
|
||||
return changed_any
|
||||
|
||||
|
||||
def _write_boot_ap_credentials_for_s30wifi():
|
||||
"""供 S30wifi AP 分支 gen_hostapd 使用的热点 SSID/密码。"""
|
||||
base = (getattr(config, "WIFI_CONFIG_AP_SSID", "ArcherySetup") or "ArcherySetup").strip()
|
||||
# 追加设备码,便于区分多台设备(读取 /device_key,失败则不加后缀)
|
||||
suffix = ""
|
||||
try:
|
||||
with open("/device_key", "r", encoding="utf-8") as f:
|
||||
dev = (f.read() or "").strip()
|
||||
if dev:
|
||||
s = dev
|
||||
# 只保留字母数字,避免 SSID 出现不可见字符
|
||||
s = "".join([c for c in s if c.isalnum()])
|
||||
if s:
|
||||
suffix = s
|
||||
except Exception:
|
||||
suffix = ""
|
||||
ssid = f"{base}_{suffix}" if suffix else base
|
||||
pwd = getattr(config, "WIFI_CONFIG_AP_PASSWORD", "12345678")
|
||||
with open("/boot/wifi.ssid", "w", encoding="utf-8") as f:
|
||||
f.write(ssid.strip())
|
||||
with open("/boot/wifi.pass", "w", encoding="utf-8") as f:
|
||||
f.write(pwd.strip())
|
||||
try:
|
||||
_ensure_hostapd_ssid(ssid.strip())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _ensure_hostapd_modern_security(logger=None) -> bool:
|
||||
"""
|
||||
确保 AP 使用较新的安全标准(至少 WPA2-PSK + CCMP)。
|
||||
你现场验证需要的两行:
|
||||
- wpa_key_mgmt=WPA-PSK
|
||||
- rsn_pairwise=CCMP
|
||||
Returns:
|
||||
bool: 若文件被修改返回 True,否则 False
|
||||
"""
|
||||
if logger is None:
|
||||
logger = logger_manager.logger
|
||||
|
||||
conf_path = "/etc/hostapd.conf"
|
||||
try:
|
||||
if not os.path.exists(conf_path):
|
||||
return False
|
||||
with open(conf_path, "r", encoding="utf-8") as f:
|
||||
lines = f.read().splitlines()
|
||||
except Exception as e:
|
||||
logger.warning(f"[WIFI-AP] 读取 hostapd.conf 失败: {e}")
|
||||
return False
|
||||
|
||||
wanted = {
|
||||
"wpa_key_mgmt": "WPA-PSK",
|
||||
"rsn_pairwise": "CCMP",
|
||||
}
|
||||
|
||||
changed = False
|
||||
seen = set()
|
||||
new_lines = []
|
||||
for ln in lines:
|
||||
s = ln.strip()
|
||||
if not s or s.startswith("#") or "=" not in s:
|
||||
new_lines.append(ln)
|
||||
continue
|
||||
k, v = s.split("=", 1)
|
||||
k = k.strip()
|
||||
if k in wanted:
|
||||
seen.add(k)
|
||||
new_v = wanted[k]
|
||||
if v.strip() != new_v:
|
||||
new_lines.append(f"{k}={new_v}")
|
||||
changed = True
|
||||
else:
|
||||
new_lines.append(ln)
|
||||
continue
|
||||
new_lines.append(ln)
|
||||
|
||||
# 缺的补到末尾
|
||||
for k, v in wanted.items():
|
||||
if k not in seen:
|
||||
new_lines.append(f"{k}={v}")
|
||||
changed = True
|
||||
|
||||
if not changed:
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(conf_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(new_lines).rstrip() + "\n")
|
||||
logger.info("[WIFI-AP] 已更新 /etc/hostapd.conf 安全参数(WPA-PSK + CCMP)")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"[WIFI-AP] 写入 hostapd.conf 失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _cleanup_ap_flag_if_needed(logger):
|
||||
"""若 /boot/wifi.ap 残留,删除它并恢复 /boot/wifi.sta,避免 main.py 误判为 AP 配网模式。"""
|
||||
ap_flag = "/boot/wifi.ap"
|
||||
sta_flag = "/boot/wifi.sta"
|
||||
if not os.path.exists(ap_flag):
|
||||
return
|
||||
try:
|
||||
os.remove(ap_flag)
|
||||
logger.info(f"[WIFI-AP] 已清理残留标记 {ap_flag}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[WIFI-AP] 清理 {ap_flag} 失败: {e}")
|
||||
return
|
||||
if not os.path.exists(sta_flag):
|
||||
try:
|
||||
with open(sta_flag, "w", encoding="utf-8") as f:
|
||||
f.write("")
|
||||
logger.info(f"[WIFI-AP] 已恢复 {sta_flag}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[WIFI-AP] 恢复 {sta_flag} 失败: {e}")
|
||||
|
||||
|
||||
def _switch_boot_to_ap_mode(logger):
|
||||
"""
|
||||
去掉 STA 标志、建立 AP 标志,由 S30wifi 起 hostapd(与 Maix start_ap 二选一,以系统脚本为准)。
|
||||
"""
|
||||
try:
|
||||
sta = "/boot/wifi.sta"
|
||||
ap = "/boot/wifi.ap"
|
||||
if os.path.exists(sta):
|
||||
os.remove(sta)
|
||||
with open(ap, "w", encoding="utf-8") as f:
|
||||
f.write("")
|
||||
os.system("/etc/init.d/S30wifi restart")
|
||||
# 某些固件生成的 hostapd.conf 缺少新安全参数,导致 Windows 提示“较旧的安全标准”。
|
||||
# 若本次修改了 hostapd.conf,则再重启一次让 hostapd 重新加载配置。
|
||||
try:
|
||||
if _ensure_hostapd_modern_security(logger):
|
||||
os.system("/etc/init.d/S30wifi restart")
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[WIFI-AP] 切换 /boot 为 AP 模式失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def start_http_server_thread():
|
||||
"""仅启动 HTTP 线程(假定 AP 已由 S30wifi 拉起)。"""
|
||||
global _http_thread
|
||||
logger = logger_manager.logger
|
||||
|
||||
if _http_thread is not None and _http_thread.is_alive():
|
||||
logger.warning("[WIFI-HTTP] 配网线程已在运行")
|
||||
return
|
||||
|
||||
_http_stop.clear()
|
||||
host = getattr(config, "WIFI_CONFIG_HTTP_HOST", "0.0.0.0")
|
||||
port = int(getattr(config, "WIFI_CONFIG_HTTP_PORT", 8080))
|
||||
|
||||
_http_thread = threading.Thread(
|
||||
target=_serve_loop,
|
||||
args=(host, port),
|
||||
daemon=True,
|
||||
name="wifi_config_httpd",
|
||||
)
|
||||
_http_thread.start()
|
||||
|
||||
|
||||
def maybe_start_wifi_ap_fallback(logger=None):
|
||||
"""
|
||||
若启用 WIFI_CONFIG_AP_FALLBACK:等待若干秒后检测 STA WiFi 与 4G,
|
||||
仅当二者均不可用时,写热点用的 /boot/wifi.ssid|pass、切到 /boot/wifi.ap 并 restart S30wifi,再启动 HTTP。
|
||||
"""
|
||||
if logger is None:
|
||||
logger = logger_manager.logger
|
||||
|
||||
if not getattr(config, "WIFI_CONFIG_AP_FALLBACK", False):
|
||||
return
|
||||
|
||||
from network import network_manager
|
||||
|
||||
# 先快速检测一次:若 STA 或 4G 已可用,直接返回,避免不必要的等待
|
||||
wifi_ok = wifi_manager.is_sta_associated()
|
||||
g4_ok = network_manager.is_4g_available()
|
||||
logger.info(f"[WIFI-AP] 兜底检测(quick):sta关联={wifi_ok}, 4g={g4_ok}")
|
||||
if wifi_ok or g4_ok:
|
||||
logger.info("[WIFI-AP] STA 或 4G 可用,不启动热点配网")
|
||||
# 清理上次开机可能残留的 /boot/wifi.ap 标记,避免 main.py 误判为 AP 配网模式
|
||||
_cleanup_ap_flag_if_needed(logger)
|
||||
return
|
||||
|
||||
# 两者均不可用:再按配置等待一段时间后复检,避免开机瞬态误判
|
||||
wait_sec = int(getattr(config, "WIFI_AP_FALLBACK_WAIT_SEC", 10))
|
||||
wait_sec = max(0, min(wait_sec, 120))
|
||||
if wait_sec > 0:
|
||||
logger.info(f"[WIFI-AP] 兜底配网:等待 {wait_sec}s 后再检测 STA/4G…")
|
||||
std_time.sleep(wait_sec)
|
||||
|
||||
# 必须用 STA 关联判断;is_wifi_connected() 在 AP 模式会因 192.168.66.1 误判为已连接
|
||||
wifi_ok = wifi_manager.is_sta_associated()
|
||||
g4_ok = network_manager.is_4g_available()
|
||||
|
||||
logger.info(f"[WIFI-AP] 兜底检测:sta关联={wifi_ok}, 4g={g4_ok}")
|
||||
|
||||
if wifi_ok or g4_ok:
|
||||
logger.info("[WIFI-AP] STA 或 4G 可用,不启动热点配网")
|
||||
_cleanup_ap_flag_if_needed(logger)
|
||||
return
|
||||
|
||||
logger.warning("[WIFI-AP] STA 与 4G 均不可用,启动热点配网(/boot/wifi.ap + HTTP)")
|
||||
|
||||
try:
|
||||
_write_boot_ap_credentials_for_s30wifi()
|
||||
except Exception as e:
|
||||
logger.error(f"[WIFI-AP] 写热点 /boot 凭证失败: {e}")
|
||||
return
|
||||
|
||||
if not _switch_boot_to_ap_mode(logger):
|
||||
return
|
||||
|
||||
std_time.sleep(3)
|
||||
start_http_server_thread()
|
||||
|
||||
p = int(getattr(config, "WIFI_CONFIG_HTTP_PORT", 8080))
|
||||
ip = getattr(config, "WIFI_CONFIG_AP_IP", "192.168.66.1")
|
||||
logger.info(f"[WIFI-AP] 请连接热点后访问 http://{ip}:{p}/ (若 IP 以 S30wifi 为准)")
|
||||
|
||||
|
||||
def stop_wifi_config_http():
|
||||
"""请求停止 HTTP 线程(下次 accept 超时后退出)。"""
|
||||
_http_stop.set()
|
||||
|
||||
|
||||
# 兼容旧名:不再使用「强制开 AP」逻辑,统一走 maybe_start_wifi_ap_fallback
|
||||
def start_wifi_config_ap_thread():
|
||||
maybe_start_wifi_ap_fallback()
|
||||
53
wpa_supplicant_conf.py
Normal file
53
wpa_supplicant_conf.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
生成 wpa_supplicant STA 配置(不经过 shell / wpa_passphrase),避免中文 SSID 在 /bin/sh 传参时被破坏。
|
||||
|
||||
与 wpa_passphrase 一致:PMK = PBKDF2-SHA1(password_utf8, ssid_utf8, 4096, 32),
|
||||
ssid 行使用 UTF-8 字节的十六进制(无引号),与 wpa_supplicant 文档一致。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
|
||||
_CTRL_HEADER = (
|
||||
"ctrl_interface=/var/run/wpa_supplicant\n"
|
||||
"update_config=1\n\n"
|
||||
)
|
||||
|
||||
|
||||
def _ssid_utf8_bytes(ssid: str) -> bytes:
|
||||
b = (ssid or "").encode("utf-8")
|
||||
if not b:
|
||||
raise ValueError("SSID 为空")
|
||||
if len(b) > 32:
|
||||
raise ValueError("SSID UTF-8 超过 32 字节")
|
||||
return b
|
||||
|
||||
|
||||
def build_sta_conf_psk(ssid: str, password: str) -> str:
|
||||
"""WPA2-PSK STA:完整 wpa_supplicant.conf 文本。"""
|
||||
ssid_b = _ssid_utf8_bytes(ssid)
|
||||
pw = (password or "").encode("utf-8")
|
||||
if len(pw) < 8 or len(pw) > 63:
|
||||
raise ValueError("WPA2-PSK 密码长度应为 8–63 字节(UTF-8)")
|
||||
pmk = hashlib.pbkdf2_hmac("sha1", pw, ssid_b, 4096, 32)
|
||||
net = (
|
||||
"network={\n"
|
||||
f"\tssid={ssid_b.hex()}\n"
|
||||
f"\tpsk={pmk.hex()}\n"
|
||||
"}\n"
|
||||
)
|
||||
return _CTRL_HEADER + net
|
||||
|
||||
|
||||
def build_sta_conf_open(ssid: str) -> str:
|
||||
"""开放网络 STA:完整 wpa_supplicant.conf 文本。"""
|
||||
ssid_b = _ssid_utf8_bytes(ssid)
|
||||
net = (
|
||||
"network={\n"
|
||||
f"\tssid={ssid_b.hex()}\n"
|
||||
"\tkey_mgmt=NONE\n"
|
||||
"}\n"
|
||||
)
|
||||
return _CTRL_HEADER + net
|
||||
Reference in New Issue
Block a user