Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config.example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ password_hash = "pbkdf2_sha256$600000$replace_me$replace_me"
session_secret = "replace-with-at-least-24-random-characters"
session_ttl_seconds = 86400
protocol_auth_enabled = true
new_connections_enabled = true
# Home Assistant/app logins use this email plus a local 6-digit PIN entered as the "code".
protocol_login_email = "you@example.com"
protocol_login_pin_hash = "pbkdf2_sha256$600000$replace_me$replace_me"
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
cloud_snapshot_path: Path | None = None,
protocol_auth_sessions_path: Path | None = None,
protocol_auth_enabled: Callable[[], bool] | None = None,
new_connections_enabled: Callable[[], bool] | None = None,
runtime_state: RuntimeState | None = None,
runtime_credentials: RuntimeCredentialsStore | None = None,
zone_ranges_store: ZoneRangesStore | None = None,
Expand All @@ -58,6 +59,7 @@ def __init__(
self.decoded_jsonl = decoded_jsonl
self.cloud_snapshot_path = cloud_snapshot_path
self._protocol_auth_enabled = protocol_auth_enabled or (lambda: True)
self._new_connections_enabled = new_connections_enabled or (lambda: True)
self.runtime_state = runtime_state
self.runtime_credentials = runtime_credentials
self.zone_ranges_store = zone_ranges_store
Expand Down Expand Up @@ -270,6 +272,8 @@ def _authorize_connect_packet_for_client(
if recovered_device is not None:
return True, "device_mqtt_recovered", info, None
if auth_reason == "unknown_device_mqtt_username":
if not self._new_connections_enabled():
return False, "new_connections_disabled", info, None
candidate = self._resolve_onboarding_device_mqtt_candidate(
client_ip=client_ip,
username=username,
Expand Down
2 changes: 2 additions & 0 deletions src/roborock_local_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class AdminConfig:
session_secret: str
session_ttl_seconds: int
protocol_auth_enabled: bool
new_connections_enabled: bool
protocol_login_email: str
protocol_login_pin_hash: str

Expand Down Expand Up @@ -247,6 +248,7 @@ def load_config(path: str | Path) -> AppConfig:
session_secret=_require_non_empty(admin.get("session_secret"), "admin.session_secret"),
session_ttl_seconds=_as_int(admin.get("session_ttl_seconds"), "admin.session_ttl_seconds", 86400),
protocol_auth_enabled=_as_bool(admin.get("protocol_auth_enabled"), True),
new_connections_enabled=_as_bool(admin.get("new_connections_enabled"), True),
protocol_login_email=_require_non_empty(admin.get("protocol_login_email"), "admin.protocol_login_email"),
protocol_login_pin_hash=_require_non_empty(
admin.get("protocol_login_pin_hash"),
Expand Down
1 change: 1 addition & 0 deletions src/roborock_local_server/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def render_config_toml(answers: ConfigureAnswers) -> str:
f"session_secret = {_toml_string(answers.session_secret)}",
"session_ttl_seconds = 86400",
"protocol_auth_enabled = true",
"new_connections_enabled = true",
f"protocol_login_email = {_toml_string(answers.protocol_login_email)}",
f"protocol_login_pin_hash = {_toml_string(answers.protocol_login_pin_hash)}",
"",
Expand Down
2 changes: 2 additions & 0 deletions src/roborock_local_server/ha_addon.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def _render_config_toml(
# The Home Assistant add-on no longer exposes this toggle.
# Keep protocol auth enabled even if a stale stored option is present.
protocol_auth_enabled = True
new_connections_enabled = True
protocol_login_email = _require_email(merged.get("protocol_login_email"), field_name="protocol_login_email")
protocol_login_pin = _require_pin(merged.get("protocol_login_pin"), field_name="protocol_login_pin")

Expand Down Expand Up @@ -272,6 +273,7 @@ def _render_config_toml(
f"session_secret = {_toml_string(admin_session_secret)}",
"session_ttl_seconds = 86400",
f"protocol_auth_enabled = {_toml_bool(protocol_auth_enabled)}",
f"new_connections_enabled = {_toml_bool(new_connections_enabled)}",
f"protocol_login_email = {_toml_string(protocol_login_email)}",
f"protocol_login_pin_hash = {_toml_string(protocol_login_pin_hash)}",
"",
Expand Down
97 changes: 97 additions & 0 deletions src/roborock_local_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,9 @@ def _require_admin(self, request: Request) -> None:
def protocol_auth_enabled(self) -> bool:
return bool(self.config.admin.protocol_auth_enabled)

def new_connections_enabled(self) -> bool:
return bool(self.config.admin.new_connections_enabled)

def _protocol_login_email(self) -> str:
return str(self.config.admin.protocol_login_email or "").strip()

Expand Down Expand Up @@ -666,6 +669,50 @@ def _is_password_reset_path(cls, clean_path: str) -> bool:
"/api/v5/user/password/email/reset",
}

@classmethod
def _is_login_flow_path(cls, clean_path: str) -> bool:
normalized = cls._normalized_path(clean_path)
return normalized in {
"/api/v1/getUrlByEmail",
"/api/v1/ml/c",
"/api/v3/key/sign",
"/api/v4/key/captcha",
} or any(
checker(normalized)
for checker in (
cls._is_code_send_path,
cls._is_code_validate_path,
cls._is_code_submit_path,
cls._is_password_login_path,
cls._is_password_reset_path,
)
)

@classmethod
def _is_onboarding_region_path(cls, clean_path: str) -> bool:
normalized = cls._normalized_path(clean_path)
return normalized.rstrip("/") in ("", "/region", "/api/region", "/b/region", "/api/b/region")

@classmethod
def _is_onboarding_nc_prepare_path(cls, clean_path: str) -> bool:
normalized = cls._normalized_path(clean_path)
return "nc" in normalized and ("prepare" in normalized or normalized.endswith("/nc"))

@classmethod
def _new_connection_flow_for_path(cls, clean_path: str) -> str | None:
normalized = cls._normalized_path(clean_path)
if (
cls._is_onboarding_region_path(normalized)
or cls._is_onboarding_nc_prepare_path(normalized)
or normalized == "/user/devices/newadd"
):
return "onboarding"
if cls._is_protocol_sync_path(normalized):
return "protocol_sync"
if cls._is_login_flow_path(normalized):
return "login"
return None

@classmethod
def _required_protocol_auth(cls, clean_path: str) -> str | None:
normalized = cls._normalized_path(clean_path)
Expand Down Expand Up @@ -730,6 +777,14 @@ def _protocol_auth_not_ready_payload(self) -> tuple[int, dict[str, Any]]:
)
return 412, payload

@staticmethod
def _new_connections_disabled_payload(flow: str) -> tuple[int, dict[str, Any]]:
return 403, {
"code": 40301,
"msg": "new_connections_disabled",
"data": {"reason": "new_connections_disabled", "flow": flow},
}

@classmethod
def _is_protocol_sync_path(cls, clean_path: str) -> bool:
return cls._normalized_path(clean_path) == PROTOCOL_AUTH_SYNC_PATH
Expand Down Expand Up @@ -947,6 +1002,38 @@ async def _handle_roborock_request(self, request: Request) -> Response:
"header_sample_added": header_sample_added,
}

blocked_flow = None if self.new_connections_enabled() else self._new_connection_flow_for_path(clean_path)
if blocked_flow is not None:
route_name = f"new_connections_disabled_{blocked_flow}"
status_code, response_payload = self._new_connections_disabled_payload(blocked_flow)
entry["route"] = route_name
entry["response_json"] = response_payload
try:
self.runtime_state.record_http_event(
event_time=str(entry["time"]),
route_name=route_name,
clean_path=clean_path,
raw_path=raw_path,
method=request.method,
host=host,
remote=str(entry["remote"]),
did=explicit_did or None,
pid=explicit_pid or None,
)
except Exception as exc: # noqa: BLE001
logger.warning("runtime_state record_http_event failed: %s", exc)
append_jsonl(self.context.http_jsonl, entry)
logger.info(
"%s %s host=%s route=%s status=%d body_sha256=%s",
request.method,
clean_path,
host or "-",
route_name,
status_code,
body_sha256[:16],
)
return JSONResponse(response_payload, status_code=status_code)

custom_sync = await self._handle_protocol_sync_route(
method=request.method,
clean_path=clean_path,
Expand Down Expand Up @@ -1304,6 +1391,7 @@ def _auth_payload(self) -> dict[str, Any]:
]
return {
"protocol_auth_enabled": self.protocol_auth_enabled(),
"new_connections_enabled": self.new_connections_enabled(),
"admin_session_secret": self.config.admin.session_secret,
"protocol_sessions": sessions,
"protocol_session_count": len(sessions),
Expand Down Expand Up @@ -1358,10 +1446,18 @@ def set_protocol_auth_enabled(self, enabled: bool) -> dict[str, Any]:
self.config = load_config(self.paths.config_file)
return self._auth_payload()

def set_new_connections_enabled(self, enabled: bool) -> dict[str, Any]:
normalized_enabled = bool(enabled)
self._rewrite_admin_bool_setting(key="new_connections_enabled", value=normalized_enabled)
self.config = load_config(self.paths.config_file)
return self._auth_payload()

def remove_protocol_session(self, *, hawk_id: str, hawk_session: str) -> bool:
return self.protocol_auth.remove_session(hawk_id=hawk_id, hawk_session=hawk_session)

def start_onboarding_session(self, *, duid: str) -> dict[str, Any]:
if not self.new_connections_enabled():
raise ValueError("New connections are disabled.")
normalized_duid = str(duid or "").strip()
if not normalized_duid:
raise ValueError("duid is required")
Expand Down Expand Up @@ -1500,6 +1596,7 @@ def _start_mqtt_proxy(self) -> None:
cloud_snapshot_path=self.paths.cloud_snapshot_path,
protocol_auth_sessions_path=self.paths.protocol_auth_sessions_path,
protocol_auth_enabled=self.protocol_auth_enabled,
new_connections_enabled=self.new_connections_enabled,
runtime_state=self.runtime_state,
runtime_credentials=self.runtime_credentials,
zone_ranges_store=self.context.zone_ranges_store,
Expand Down
51 changes: 32 additions & 19 deletions src/roborock_local_server/standalone_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def _admin_dashboard_html(project_support: dict[str, Any]) -> str:
<button id="fetchData">Fetch Data</button>
<pre id="cloudResult">No cloud request yet.</pre>
</section>
<section><h2>Protocol Auth</h2>
<label><input id="protocolAuthEnabled" type="checkbox" /> Require token/Hawk auth on protocol API routes</label>
<button id="saveAuth" style="margin-left:8px">Save</button>
<div id="authMeta" style="margin-top:8px;color:#333">Loading auth state...</div>
<section><h2>New Connections</h2>
<label><input id="newConnectionsEnabled" type="checkbox" /> Allow new app logins, onboarding, and first-time vacuum connections</label>
<button id="saveConnections" style="margin-left:8px">Save</button>
<div id="authMeta" style="margin-top:8px;color:#333">Loading connection state...</div>
<div style="margin-top:12px">
<div style="font-weight:600">Protocol Sync Secret</div>
<div style="margin-top:6px;display:flex;gap:8px;flex-wrap:wrap;align-items:center">
Expand All @@ -71,6 +71,7 @@ def _admin_dashboard_html(project_support: dict[str, Any]) -> str:
<div id="syncSecretMeta" style="margin-top:6px;color:#555">Use this with <code>mitm_redirect.py --sync-secret ...</code>.</div>
</div>
<div id="pendingRecovery" style="margin-top:8px"></div>
<div style="margin-top:12px;font-weight:600">Protocol Sessions</div>
<div id="sessionList" style="display:grid;gap:8px;margin-top:12px">Loading sessions...</div>
</section>

Expand Down Expand Up @@ -144,10 +145,10 @@ def _admin_dashboard_html(project_support: dict[str, Any]) -> str:
}}
}}
function renderAuth(auth) {{
const enabled = Boolean(auth.protocol_auth_enabled);
document.getElementById("protocolAuthEnabled").checked = enabled;
const enabled = Boolean(auth.new_connections_enabled);
document.getElementById("newConnectionsEnabled").checked = enabled;
document.getElementById("authMeta").textContent =
`Protocol auth: ${{enabled ? "Enabled" : "Disabled"}}. Persisted sessions: ${{Number(auth.protocol_session_count || 0)}}.`;
`New connections: ${{enabled ? "Allowed" : "Blocked"}}. Persisted sessions: ${{Number(auth.protocol_session_count || 0)}}.`;
const sessionSecret = String(auth.admin_session_secret || "");
document.getElementById("adminSessionSecret").value = sessionSecret;
document.getElementById("syncSecretMeta").textContent = sessionSecret
Expand Down Expand Up @@ -246,13 +247,13 @@ def _admin_dashboard_html(project_support: dict[str, Any]) -> str:
document.getElementById("cloudResult").textContent = error.message;
}}
}});
document.getElementById("saveAuth").addEventListener("click", async () => {{
document.getElementById("saveConnections").addEventListener("click", async () => {{
try {{
const payload = await fetchJson("/admin/api/auth", {{
method: "POST",
headers: {{"Content-Type":"application/json"}},
body: JSON.stringify({{
protocol_auth_enabled: document.getElementById("protocolAuthEnabled").checked
new_connections_enabled: document.getElementById("newConnectionsEnabled").checked
}})
}});
renderAuth(payload);
Expand Down Expand Up @@ -355,16 +356,28 @@ async def admin_auth_update(request: Request) -> JSONResponse:
return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
if not isinstance(body, dict):
return JSONResponse({"error": "JSON body must be an object"}, status_code=400)
if "protocol_auth_enabled" not in body:
return JSONResponse({"error": "protocol_auth_enabled is required"}, status_code=400)
protocol_auth_enabled = body.get("protocol_auth_enabled")
if not isinstance(protocol_auth_enabled, bool):
return JSONResponse({"error": "protocol_auth_enabled must be a boolean"}, status_code=400)
try:
payload = supervisor.set_protocol_auth_enabled(protocol_auth_enabled)
except Exception as exc: # noqa: BLE001
return JSONResponse({"error": str(exc)}, status_code=500)
return JSONResponse(payload)
if "new_connections_enabled" in body:
new_connections_enabled = body.get("new_connections_enabled")
if not isinstance(new_connections_enabled, bool):
return JSONResponse({"error": "new_connections_enabled must be a boolean"}, status_code=400)
try:
payload = supervisor.set_new_connections_enabled(new_connections_enabled)
except Exception as exc: # noqa: BLE001
return JSONResponse({"error": str(exc)}, status_code=500)
return JSONResponse(payload)
if "protocol_auth_enabled" in body:
protocol_auth_enabled = body.get("protocol_auth_enabled")
if not isinstance(protocol_auth_enabled, bool):
return JSONResponse({"error": "protocol_auth_enabled must be a boolean"}, status_code=400)
try:
payload = supervisor.set_protocol_auth_enabled(protocol_auth_enabled)
except Exception as exc: # noqa: BLE001
return JSONResponse({"error": str(exc)}, status_code=500)
return JSONResponse(payload)
return JSONResponse(
{"error": "new_connections_enabled or protocol_auth_enabled is required"},
status_code=400,
)

@app.delete("/admin/api/auth/sessions/{hawk_id}/{hawk_session}")
async def admin_auth_delete_session(hawk_id: str, hawk_session: str, request: Request) -> JSONResponse:
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def write_release_config(
broker_mode: str = "external",
enable_topic_bridge: bool = False,
protocol_auth_enabled: bool = True,
new_connections_enabled: bool = True,
protocol_login_email: str = "user@example.com",
protocol_login_pin: str = "123456",
) -> Path:
Expand Down Expand Up @@ -58,6 +59,7 @@ def write_release_config(
session_secret = "abcdefghijklmnopqrstuvwxyz123456"
session_ttl_seconds = 3600
protocol_auth_enabled = {"true" if protocol_auth_enabled else "false"}
new_connections_enabled = {"true" if new_connections_enabled else "false"}
protocol_login_email = "{protocol_login_email}"
protocol_login_pin_hash = "{hash_password(protocol_login_pin, iterations=10_000)}"
""".strip(),
Expand Down
Loading