diff --git a/config.example.toml b/config.example.toml index dbf29fd..3216359 100644 --- a/config.example.toml +++ b/config.example.toml @@ -33,6 +33,7 @@ password_hash = "pbkdf2_sha256$600000$replace_me$replace_me" session_secret = "replace-with-at-least-24-random-characters" session_ttl_seconds = 86400 protocol_auth_enabled = true +new_connections_enabled = true # Home Assistant/app logins use this email plus a local 6-digit PIN entered as the "code". protocol_login_email = "you@example.com" protocol_login_pin_hash = "pbkdf2_sha256$600000$replace_me$replace_me" diff --git a/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py b/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py index 368c472..851217c 100644 --- a/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py +++ b/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py @@ -41,6 +41,7 @@ def __init__( cloud_snapshot_path: Path | None = None, protocol_auth_sessions_path: Path | None = None, protocol_auth_enabled: Callable[[], bool] | None = None, + new_connections_enabled: Callable[[], bool] | None = None, runtime_state: RuntimeState | None = None, runtime_credentials: RuntimeCredentialsStore | None = None, zone_ranges_store: ZoneRangesStore | None = None, @@ -58,6 +59,7 @@ def __init__( self.decoded_jsonl = decoded_jsonl self.cloud_snapshot_path = cloud_snapshot_path self._protocol_auth_enabled = protocol_auth_enabled or (lambda: True) + self._new_connections_enabled = new_connections_enabled or (lambda: True) self.runtime_state = runtime_state self.runtime_credentials = runtime_credentials self.zone_ranges_store = zone_ranges_store @@ -270,6 +272,8 @@ def _authorize_connect_packet_for_client( if recovered_device is not None: return True, "device_mqtt_recovered", info, None if auth_reason == "unknown_device_mqtt_username": + if not self._new_connections_enabled(): + return False, "new_connections_disabled", info, None candidate = self._resolve_onboarding_device_mqtt_candidate( client_ip=client_ip, username=username, diff --git a/src/roborock_local_server/config.py b/src/roborock_local_server/config.py index 07eeb6d..7206b2d 100644 --- a/src/roborock_local_server/config.py +++ b/src/roborock_local_server/config.py @@ -60,6 +60,7 @@ class AdminConfig: session_secret: str session_ttl_seconds: int protocol_auth_enabled: bool + new_connections_enabled: bool protocol_login_email: str protocol_login_pin_hash: str @@ -247,6 +248,7 @@ def load_config(path: str | Path) -> AppConfig: session_secret=_require_non_empty(admin.get("session_secret"), "admin.session_secret"), session_ttl_seconds=_as_int(admin.get("session_ttl_seconds"), "admin.session_ttl_seconds", 86400), protocol_auth_enabled=_as_bool(admin.get("protocol_auth_enabled"), True), + new_connections_enabled=_as_bool(admin.get("new_connections_enabled"), True), protocol_login_email=_require_non_empty(admin.get("protocol_login_email"), "admin.protocol_login_email"), protocol_login_pin_hash=_require_non_empty( admin.get("protocol_login_pin_hash"), diff --git a/src/roborock_local_server/configure.py b/src/roborock_local_server/configure.py index 703edb3..bfe5c88 100644 --- a/src/roborock_local_server/configure.py +++ b/src/roborock_local_server/configure.py @@ -359,6 +359,7 @@ def render_config_toml(answers: ConfigureAnswers) -> str: f"session_secret = {_toml_string(answers.session_secret)}", "session_ttl_seconds = 86400", "protocol_auth_enabled = true", + "new_connections_enabled = true", f"protocol_login_email = {_toml_string(answers.protocol_login_email)}", f"protocol_login_pin_hash = {_toml_string(answers.protocol_login_pin_hash)}", "", diff --git a/src/roborock_local_server/ha_addon.py b/src/roborock_local_server/ha_addon.py index 97688ef..6089dd4 100644 --- a/src/roborock_local_server/ha_addon.py +++ b/src/roborock_local_server/ha_addon.py @@ -202,6 +202,7 @@ def _render_config_toml( # The Home Assistant add-on no longer exposes this toggle. # Keep protocol auth enabled even if a stale stored option is present. protocol_auth_enabled = True + new_connections_enabled = True protocol_login_email = _require_email(merged.get("protocol_login_email"), field_name="protocol_login_email") protocol_login_pin = _require_pin(merged.get("protocol_login_pin"), field_name="protocol_login_pin") @@ -272,6 +273,7 @@ def _render_config_toml( f"session_secret = {_toml_string(admin_session_secret)}", "session_ttl_seconds = 86400", f"protocol_auth_enabled = {_toml_bool(protocol_auth_enabled)}", + f"new_connections_enabled = {_toml_bool(new_connections_enabled)}", f"protocol_login_email = {_toml_string(protocol_login_email)}", f"protocol_login_pin_hash = {_toml_string(protocol_login_pin_hash)}", "", diff --git a/src/roborock_local_server/server.py b/src/roborock_local_server/server.py index df30aed..c4a63ad 100644 --- a/src/roborock_local_server/server.py +++ b/src/roborock_local_server/server.py @@ -491,6 +491,9 @@ def _require_admin(self, request: Request) -> None: def protocol_auth_enabled(self) -> bool: return bool(self.config.admin.protocol_auth_enabled) + def new_connections_enabled(self) -> bool: + return bool(self.config.admin.new_connections_enabled) + def _protocol_login_email(self) -> str: return str(self.config.admin.protocol_login_email or "").strip() @@ -666,6 +669,50 @@ def _is_password_reset_path(cls, clean_path: str) -> bool: "/api/v5/user/password/email/reset", } + @classmethod + def _is_login_flow_path(cls, clean_path: str) -> bool: + normalized = cls._normalized_path(clean_path) + return normalized in { + "/api/v1/getUrlByEmail", + "/api/v1/ml/c", + "/api/v3/key/sign", + "/api/v4/key/captcha", + } or any( + checker(normalized) + for checker in ( + cls._is_code_send_path, + cls._is_code_validate_path, + cls._is_code_submit_path, + cls._is_password_login_path, + cls._is_password_reset_path, + ) + ) + + @classmethod + def _is_onboarding_region_path(cls, clean_path: str) -> bool: + normalized = cls._normalized_path(clean_path) + return normalized.rstrip("/") in ("", "/region", "/api/region", "/b/region", "/api/b/region") + + @classmethod + def _is_onboarding_nc_prepare_path(cls, clean_path: str) -> bool: + normalized = cls._normalized_path(clean_path) + return "nc" in normalized and ("prepare" in normalized or normalized.endswith("/nc")) + + @classmethod + def _new_connection_flow_for_path(cls, clean_path: str) -> str | None: + normalized = cls._normalized_path(clean_path) + if ( + cls._is_onboarding_region_path(normalized) + or cls._is_onboarding_nc_prepare_path(normalized) + or normalized == "/user/devices/newadd" + ): + return "onboarding" + if cls._is_protocol_sync_path(normalized): + return "protocol_sync" + if cls._is_login_flow_path(normalized): + return "login" + return None + @classmethod def _required_protocol_auth(cls, clean_path: str) -> str | None: normalized = cls._normalized_path(clean_path) @@ -730,6 +777,14 @@ def _protocol_auth_not_ready_payload(self) -> tuple[int, dict[str, Any]]: ) return 412, payload + @staticmethod + def _new_connections_disabled_payload(flow: str) -> tuple[int, dict[str, Any]]: + return 403, { + "code": 40301, + "msg": "new_connections_disabled", + "data": {"reason": "new_connections_disabled", "flow": flow}, + } + @classmethod def _is_protocol_sync_path(cls, clean_path: str) -> bool: return cls._normalized_path(clean_path) == PROTOCOL_AUTH_SYNC_PATH @@ -947,6 +1002,38 @@ async def _handle_roborock_request(self, request: Request) -> Response: "header_sample_added": header_sample_added, } + blocked_flow = None if self.new_connections_enabled() else self._new_connection_flow_for_path(clean_path) + if blocked_flow is not None: + route_name = f"new_connections_disabled_{blocked_flow}" + status_code, response_payload = self._new_connections_disabled_payload(blocked_flow) + entry["route"] = route_name + entry["response_json"] = response_payload + try: + self.runtime_state.record_http_event( + event_time=str(entry["time"]), + route_name=route_name, + clean_path=clean_path, + raw_path=raw_path, + method=request.method, + host=host, + remote=str(entry["remote"]), + did=explicit_did or None, + pid=explicit_pid or None, + ) + except Exception as exc: # noqa: BLE001 + logger.warning("runtime_state record_http_event failed: %s", exc) + append_jsonl(self.context.http_jsonl, entry) + logger.info( + "%s %s host=%s route=%s status=%d body_sha256=%s", + request.method, + clean_path, + host or "-", + route_name, + status_code, + body_sha256[:16], + ) + return JSONResponse(response_payload, status_code=status_code) + custom_sync = await self._handle_protocol_sync_route( method=request.method, clean_path=clean_path, @@ -1304,6 +1391,7 @@ def _auth_payload(self) -> dict[str, Any]: ] return { "protocol_auth_enabled": self.protocol_auth_enabled(), + "new_connections_enabled": self.new_connections_enabled(), "admin_session_secret": self.config.admin.session_secret, "protocol_sessions": sessions, "protocol_session_count": len(sessions), @@ -1358,10 +1446,18 @@ def set_protocol_auth_enabled(self, enabled: bool) -> dict[str, Any]: self.config = load_config(self.paths.config_file) return self._auth_payload() + def set_new_connections_enabled(self, enabled: bool) -> dict[str, Any]: + normalized_enabled = bool(enabled) + self._rewrite_admin_bool_setting(key="new_connections_enabled", value=normalized_enabled) + self.config = load_config(self.paths.config_file) + return self._auth_payload() + def remove_protocol_session(self, *, hawk_id: str, hawk_session: str) -> bool: return self.protocol_auth.remove_session(hawk_id=hawk_id, hawk_session=hawk_session) def start_onboarding_session(self, *, duid: str) -> dict[str, Any]: + if not self.new_connections_enabled(): + raise ValueError("New connections are disabled.") normalized_duid = str(duid or "").strip() if not normalized_duid: raise ValueError("duid is required") @@ -1500,6 +1596,7 @@ def _start_mqtt_proxy(self) -> None: cloud_snapshot_path=self.paths.cloud_snapshot_path, protocol_auth_sessions_path=self.paths.protocol_auth_sessions_path, protocol_auth_enabled=self.protocol_auth_enabled, + new_connections_enabled=self.new_connections_enabled, runtime_state=self.runtime_state, runtime_credentials=self.runtime_credentials, zone_ranges_store=self.context.zone_ranges_store, diff --git a/src/roborock_local_server/standalone_admin.py b/src/roborock_local_server/standalone_admin.py index abeb64d..4b1bfc5 100644 --- a/src/roborock_local_server/standalone_admin.py +++ b/src/roborock_local_server/standalone_admin.py @@ -58,10 +58,10 @@ def _admin_dashboard_html(project_support: dict[str, Any]) -> str:
No cloud request yet.
-

Protocol Auth

- - -
Loading auth state...
+

New Connections

+ + +
Loading connection state...
Protocol Sync Secret
@@ -71,6 +71,7 @@ def _admin_dashboard_html(project_support: dict[str, Any]) -> str:
Use this with mitm_redirect.py --sync-secret ....
+
Protocol Sessions
Loading sessions...
@@ -144,10 +145,10 @@ def _admin_dashboard_html(project_support: dict[str, Any]) -> str: }} }} function renderAuth(auth) {{ - const enabled = Boolean(auth.protocol_auth_enabled); - document.getElementById("protocolAuthEnabled").checked = enabled; + const enabled = Boolean(auth.new_connections_enabled); + document.getElementById("newConnectionsEnabled").checked = enabled; document.getElementById("authMeta").textContent = - `Protocol auth: ${{enabled ? "Enabled" : "Disabled"}}. Persisted sessions: ${{Number(auth.protocol_session_count || 0)}}.`; + `New connections: ${{enabled ? "Allowed" : "Blocked"}}. Persisted sessions: ${{Number(auth.protocol_session_count || 0)}}.`; const sessionSecret = String(auth.admin_session_secret || ""); document.getElementById("adminSessionSecret").value = sessionSecret; document.getElementById("syncSecretMeta").textContent = sessionSecret @@ -246,13 +247,13 @@ def _admin_dashboard_html(project_support: dict[str, Any]) -> str: document.getElementById("cloudResult").textContent = error.message; }} }}); - document.getElementById("saveAuth").addEventListener("click", async () => {{ + document.getElementById("saveConnections").addEventListener("click", async () => {{ try {{ const payload = await fetchJson("/admin/api/auth", {{ method: "POST", headers: {{"Content-Type":"application/json"}}, body: JSON.stringify({{ - protocol_auth_enabled: document.getElementById("protocolAuthEnabled").checked + new_connections_enabled: document.getElementById("newConnectionsEnabled").checked }}) }}); renderAuth(payload); @@ -355,16 +356,28 @@ async def admin_auth_update(request: Request) -> JSONResponse: return JSONResponse({"error": "Invalid JSON body"}, status_code=400) if not isinstance(body, dict): return JSONResponse({"error": "JSON body must be an object"}, status_code=400) - if "protocol_auth_enabled" not in body: - return JSONResponse({"error": "protocol_auth_enabled is required"}, status_code=400) - protocol_auth_enabled = body.get("protocol_auth_enabled") - if not isinstance(protocol_auth_enabled, bool): - return JSONResponse({"error": "protocol_auth_enabled must be a boolean"}, status_code=400) - try: - payload = supervisor.set_protocol_auth_enabled(protocol_auth_enabled) - except Exception as exc: # noqa: BLE001 - return JSONResponse({"error": str(exc)}, status_code=500) - return JSONResponse(payload) + if "new_connections_enabled" in body: + new_connections_enabled = body.get("new_connections_enabled") + if not isinstance(new_connections_enabled, bool): + return JSONResponse({"error": "new_connections_enabled must be a boolean"}, status_code=400) + try: + payload = supervisor.set_new_connections_enabled(new_connections_enabled) + except Exception as exc: # noqa: BLE001 + return JSONResponse({"error": str(exc)}, status_code=500) + return JSONResponse(payload) + if "protocol_auth_enabled" in body: + protocol_auth_enabled = body.get("protocol_auth_enabled") + if not isinstance(protocol_auth_enabled, bool): + return JSONResponse({"error": "protocol_auth_enabled must be a boolean"}, status_code=400) + try: + payload = supervisor.set_protocol_auth_enabled(protocol_auth_enabled) + except Exception as exc: # noqa: BLE001 + return JSONResponse({"error": str(exc)}, status_code=500) + return JSONResponse(payload) + return JSONResponse( + {"error": "new_connections_enabled or protocol_auth_enabled is required"}, + status_code=400, + ) @app.delete("/admin/api/auth/sessions/{hawk_id}/{hawk_session}") async def admin_auth_delete_session(hawk_id: str, hawk_session: str, request: Request) -> JSONResponse: diff --git a/tests/conftest.py b/tests/conftest.py index 9d1f568..5e34fc5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,7 @@ def write_release_config( broker_mode: str = "external", enable_topic_bridge: bool = False, protocol_auth_enabled: bool = True, + new_connections_enabled: bool = True, protocol_login_email: str = "user@example.com", protocol_login_pin: str = "123456", ) -> Path: @@ -58,6 +59,7 @@ def write_release_config( session_secret = "abcdefghijklmnopqrstuvwxyz123456" session_ttl_seconds = 3600 protocol_auth_enabled = {"true" if protocol_auth_enabled else "false"} +new_connections_enabled = {"true" if new_connections_enabled else "false"} protocol_login_email = "{protocol_login_email}" protocol_login_pin_hash = "{hash_password(protocol_login_pin, iterations=10_000)}" """.strip(), diff --git a/tests/test_admin_api.py b/tests/test_admin_api.py index dbc11fc..93b6124 100644 --- a/tests/test_admin_api.py +++ b/tests/test_admin_api.py @@ -242,7 +242,8 @@ def test_admin_login_and_status_flow(tmp_path: Path) -> None: dashboard_page = client.get("/admin") assert dashboard_page.status_code == 200 assert "Cloud Import" in dashboard_page.text - assert "Protocol Auth" in dashboard_page.text + assert "New Connections" in dashboard_page.text + assert "Allow new app logins, onboarding, and first-time vacuum connections" in dashboard_page.text assert "Protocol Sync Secret" in dashboard_page.text assert "Num query samples" in dashboard_page.text @@ -261,7 +262,7 @@ def test_admin_login_and_status_flow(tmp_path: Path) -> None: assert status_after_logout.status_code == 401 -def test_admin_auth_endpoints_toggle_protocol_auth_and_manage_sessions(tmp_path: Path) -> None: +def test_admin_auth_endpoints_toggle_new_connections_and_manage_sessions(tmp_path: Path) -> None: config_file = write_release_config(tmp_path) config = load_config(config_file) paths = resolve_paths(config_file, config) @@ -283,17 +284,22 @@ def test_admin_auth_endpoints_toggle_protocol_auth_and_manage_sessions(tmp_path: assert auth_payload.status_code == 200 auth_json = auth_payload.json() assert auth_json["protocol_auth_enabled"] is True + assert auth_json["new_connections_enabled"] is True assert auth_json["admin_session_secret"] == config.admin.session_secret assert auth_json["protocol_session_count"] >= 1 session = next(item for item in auth_json["protocol_sessions"] if item["hawk_id"] == issued["rriot"]["u"]) - toggled = client.post("/admin/api/auth", json={"protocol_auth_enabled": False}) + toggled = client.post("/admin/api/auth", json={"new_connections_enabled": False}) assert toggled.status_code == 200 - assert toggled.json()["protocol_auth_enabled"] is False - assert 'protocol_auth_enabled = false' in paths.config_file.read_text(encoding="utf-8") + assert toggled.json()["new_connections_enabled"] is False + assert 'new_connections_enabled = false' in paths.config_file.read_text(encoding="utf-8") - unauthed_home = client.get("/api/v1/getHomeDetail") - assert unauthed_home.status_code == 200 + blocked_login = client.post( + "/api/v5/auth/email/login/code", + json={"email": "user@example.com", "code": "123456"}, + ) + assert blocked_login.status_code == 403 + assert blocked_login.json()["msg"] == "new_connections_disabled" deleted = client.delete(f"/admin/api/auth/sessions/{session['hawk_id']}/{session['hawk_session']}") assert deleted.status_code == 200 @@ -314,9 +320,9 @@ def test_admin_auth_update_rejects_invalid_payload_types(tmp_path: Path) -> None login = client.post("/admin/api/login", json={"password": "correct horse battery staple"}) assert login.status_code == 200 - invalid_string = client.post("/admin/api/auth", json={"protocol_auth_enabled": "false"}) + invalid_string = client.post("/admin/api/auth", json={"new_connections_enabled": "false"}) assert invalid_string.status_code == 400 - assert invalid_string.json()["error"] == "protocol_auth_enabled must be a boolean" + assert invalid_string.json()["error"] == "new_connections_enabled must be a boolean" invalid_container = client.post("/admin/api/auth", json=["not-an-object"]) assert invalid_container.status_code == 400 @@ -330,6 +336,10 @@ def test_admin_auth_update_rejects_invalid_payload_types(tmp_path: Path) -> None assert invalid_json.status_code == 400 assert invalid_json.json()["error"] == "Invalid JSON body" + missing_toggle = client.post("/admin/api/auth", json={}) + assert missing_toggle.status_code == 400 + assert missing_toggle.json()["error"] == "new_connections_enabled or protocol_auth_enabled is required" + def test_set_protocol_auth_enabled_rewrites_only_exact_admin_key(tmp_path: Path) -> None: config_file = write_release_config(tmp_path) @@ -354,6 +364,29 @@ def test_set_protocol_auth_enabled_rewrites_only_exact_admin_key(tmp_path: Path) assert rendered.count("protocol_auth_enabled = false") == 1 +def test_set_new_connections_enabled_rewrites_only_exact_admin_key(tmp_path: Path) -> None: + config_file = write_release_config(tmp_path) + original = config_file.read_text(encoding="utf-8") + modified = original.replace( + "new_connections_enabled = true", + "# new_connections_enabled = true\nnew_connections_enabled_backup = true", + ) + config_file.write_text(modified, encoding="utf-8") + + config = load_config(config_file) + paths = resolve_paths(config_file, config) + supervisor = ReleaseSupervisor(config=config, paths=paths) + + payload = supervisor.set_new_connections_enabled(False) + + rendered = config_file.read_text(encoding="utf-8") + assert payload["new_connections_enabled"] is False + assert "# new_connections_enabled = true" in rendered + assert "new_connections_enabled_backup = true" in rendered + assert "new_connections_enabled = false" in rendered + assert rendered.count("new_connections_enabled = false") == 1 + + def test_admin_onboarding_endpoints_require_auth_and_manage_session(tmp_path: Path) -> None: config_file = write_release_config(tmp_path) config = load_config(config_file) @@ -457,6 +490,40 @@ def test_admin_onboarding_endpoints_require_auth_and_manage_session(tmp_path: Pa assert deleted_missing.status_code == 404 +def test_admin_onboarding_start_is_blocked_when_new_connections_disabled(tmp_path: Path) -> None: + config_file = write_release_config(tmp_path, new_connections_enabled=False) + config = load_config(config_file) + paths = resolve_paths(config_file, config) + paths.inventory_path.parent.mkdir(parents=True, exist_ok=True) + paths.inventory_path.write_text( + json.dumps( + { + "devices": [ + { + "duid": "cloud-q7-a", + "did": "1103821560705", + "name": "Q7 Upstairs", + "model": "roborock.vacuum.sc05", + "product_id": "product-q7-a", + } + ] + } + ) + + "\n", + encoding="utf-8", + ) + supervisor = ReleaseSupervisor(config=config, paths=paths) + supervisor.refresh_inventory_state() + client = TestClient(supervisor.app) + + login = client.post("/admin/api/login", json={"password": "correct horse battery staple"}) + assert login.status_code == 200 + + started = client.post("/admin/api/onboarding/sessions", json={"duid": "cloud-q7-a"}) + assert started.status_code == 400 + assert started.json()["error"] == "New connections are disabled." + + def test_core_only_mode_disables_standalone_admin_routes(tmp_path: Path) -> None: config_file = write_release_config(tmp_path) config = load_config(config_file) diff --git a/tests/test_config.py b/tests/test_config.py index 3489635..28c5bae 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -38,6 +38,7 @@ def test_load_config_and_resolve_paths(tmp_path: Path) -> None: assert config.network.https_port == 555 assert config.network.mqtt_tls_port == 8881 assert config.admin.protocol_auth_enabled is True + assert config.admin.new_connections_enabled is True assert config.admin.protocol_login_email == "user@example.com" assert paths.data_dir == (tmp_path / "data").resolve() assert paths.cert_file == (tmp_path / "certs" / "fullchain.pem").resolve() diff --git a/tests/test_configure.py b/tests/test_configure.py index 7d830fd..61e311f 100644 --- a/tests/test_configure.py +++ b/tests/test_configure.py @@ -59,6 +59,7 @@ def test_write_config_setup_embedded_cloudflare(tmp_path: Path) -> None: assert config.tls.cloudflare_token_file == "/run/secrets/cloudflare_token" assert config.tls.acme_server == "zerossl" assert config.admin.protocol_auth_enabled is True + assert config.admin.new_connections_enabled is True assert config.admin.protocol_login_email == "user@example.com" @@ -77,6 +78,7 @@ def test_write_config_setup_external_broker_requires_host_before_serve(tmp_path: assert 'host = ""' in rendered assert "port = 1883" in rendered assert "protocol_auth_enabled = true" in rendered + assert "new_connections_enabled = true" in rendered assert 'protocol_login_email = "user@example.com"' in rendered with pytest.raises(ValueError, match="broker.host is required"): diff --git a/tests/test_ha_addon.py b/tests/test_ha_addon.py index 73232f7..9c7b4d8 100644 --- a/tests/test_ha_addon.py +++ b/tests/test_ha_addon.py @@ -52,6 +52,7 @@ def test_write_config_from_home_assistant_options_provided_tls(tmp_path: Path) - assert parsed["tls"]["cert_file"] == "/ssl/fullchain.pem" assert parsed["tls"]["key_file"] == "/ssl/privkey.pem" assert parsed["admin"]["protocol_auth_enabled"] is True + assert parsed["admin"]["new_connections_enabled"] is True assert parsed["admin"]["protocol_login_email"] == "user@example.com" assert len(str(parsed["admin"]["session_secret"])) >= 24 assert str(parsed["admin"]["password_hash"]).startswith("pbkdf2_sha256$") @@ -107,6 +108,7 @@ def test_write_config_from_home_assistant_options_ignores_legacy_protocol_auth_t parsed = tomllib.loads(config_path.read_text(encoding="utf-8")) assert parsed["admin"]["protocol_auth_enabled"] is True + assert parsed["admin"]["new_connections_enabled"] is True def test_write_config_from_home_assistant_options_reuses_existing_session_secret(tmp_path: Path) -> None: diff --git a/tests/test_mqtt_tls_proxy.py b/tests/test_mqtt_tls_proxy.py index 5c966c1..73bcc39 100644 --- a/tests/test_mqtt_tls_proxy.py +++ b/tests/test_mqtt_tls_proxy.py @@ -525,6 +525,85 @@ def test_authorize_connect_accepts_unknown_device_credentials_only_for_matching_ assert rejected_candidate is None +def test_authorize_connect_rejects_onboarding_candidate_when_new_connections_disabled(tmp_path) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) + key_state_path = tmp_path / "device_key_state.json" + _seed_key_state(key_state_path, did="1103821560705") + runtime_credentials_path = tmp_path / "runtime_credentials.json" + _write_json( + runtime_credentials_path, + { + "schema_version": 2, + "mqtt_usr": "bootstrap-user", + "mqtt_passwd": "bootstrap-pass", + "mqtt_clientid": "bootstrap-client", + "devices": [ + { + "did": "1103821560705", + "duid": "6HL2zfniaoYYV01CkVuhkO", + "name": "Roborock Qrevo MaxV 2", + "model": "roborock.vacuum.a87", + "product_id": "5gUei3OIJIXVD3eD85Balg", + "localkey": "xPd5Dr8CGGqtdDlH", + "local_key_source": "inventory", + "device_mqtt_usr": "", + "device_mqtt_pass": "", + "updated_at": "2026-04-17T17:00:00+00:00", + "last_nc_at": "", + "last_mqtt_seen_at": "", + } + ], + }, + ) + runtime_credentials = RuntimeCredentialsStore(runtime_credentials_path) + runtime_state = RuntimeState(log_dir=tmp_path, key_state_file=key_state_path, runtime_credentials=runtime_credentials) + runtime_state.upsert_vacuum("6HL2zfniaoYYV01CkVuhkO", name="Roborock Qrevo MaxV 2", id_kind="duid") + runtime_state.start_onboarding_session(target_duid="6HL2zfniaoYYV01CkVuhkO", target_name="Roborock Qrevo MaxV 2") + event_time = datetime.now(timezone.utc).isoformat() + for route_name, path_name in (("region", "/region"), ("nc_prepare", "/nc")): + runtime_state.record_http_event( + event_time=event_time, + route_name=route_name, + clean_path=path_name, + raw_path=path_name, + method="GET", + host="api-roborock.example.com", + remote="192.168.8.10:54321", + did="1103821560705", + ) + proxy = MqttTlsProxy( + cert_file=tmp_path / "fullchain.pem", + key_file=tmp_path / "privkey.pem", + listen_host="127.0.0.1", + listen_port=8883, + backend_host="127.0.0.1", + backend_port=1883, + localkey="test-local-key", + logger=logging.getLogger("test.mqtt_tls_proxy"), + decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, + runtime_state=runtime_state, + runtime_credentials=runtime_credentials, + new_connections_enabled=lambda: False, + ) + + packet = _build_connect_packet( + client_id="a012391cb5f8bc97", + username="c25b14ceac358d2a", + password="ff8922d24a9a9af81f18f35dcee9a5a5", + ) + authorized, reason, info, candidate = proxy._authorize_connect_packet_for_client( + packet, + client_ip="192.168.8.10", + ) + + assert authorized is False + assert reason == "new_connections_disabled" + assert info is not None + assert candidate is None + + def test_trace_packet_persists_confirmed_onboarding_device_mqtt_credentials(tmp_path) -> None: cloud_snapshot_path = tmp_path / "cloud_snapshot.json" _seed_cloud_snapshot(cloud_snapshot_path) diff --git a/tests/test_protocol_auth.py b/tests/test_protocol_auth.py index 55aacf8..0247b02 100644 --- a/tests/test_protocol_auth.py +++ b/tests/test_protocol_auth.py @@ -87,8 +87,13 @@ def _build_supervisor_with_protocol_toggle( tmp_path: Path, *, protocol_auth_enabled: bool, + new_connections_enabled: bool = True, ) -> tuple[ReleaseSupervisor, object]: - config_file = write_release_config(tmp_path, protocol_auth_enabled=protocol_auth_enabled) + config_file = write_release_config( + tmp_path, + protocol_auth_enabled=protocol_auth_enabled, + new_connections_enabled=new_connections_enabled, + ) config = load_config(config_file) paths = resolve_paths(config_file, config) _write_json(paths.inventory_path, {"home": {"id": 12345, "name": "Test Home"}, "devices": []}) @@ -309,6 +314,39 @@ def test_protocol_password_login_is_rejected(tmp_path: Path) -> None: assert response.json()["msg"] == "password_login_not_supported" +def test_protocol_login_and_onboarding_routes_block_when_new_connections_disabled(tmp_path: Path) -> None: + supervisor, _paths = _build_supervisor_with_protocol_toggle( + tmp_path, + protocol_auth_enabled=True, + new_connections_enabled=False, + ) + client = TestClient(supervisor.app) + + login_response = client.post( + "/api/v5/auth/email/login/code", + json={"email": "user@example.com", "code": "123456"}, + ) + assert login_response.status_code == 403 + assert login_response.json()["msg"] == "new_connections_disabled" + assert login_response.json()["data"]["flow"] == "login" + + region_response = client.get("/region") + assert region_response.status_code == 403 + assert region_response.json()["data"]["flow"] == "onboarding" + + api_region_response = client.get("/api/region") + assert api_region_response.status_code == 403 + assert api_region_response.json()["data"]["flow"] == "onboarding" + + bare_nc_response = client.get("/nc") + assert bare_nc_response.status_code == 403 + assert bare_nc_response.json()["data"]["flow"] == "onboarding" + + newadd_response = client.get("/user/devices/newadd") + assert newadd_response.status_code == 403 + assert newadd_response.json()["data"]["flow"] == "onboarding" + + def test_protocol_sync_route_persists_additional_sessions_and_redacts_logs(tmp_path: Path) -> None: supervisor, paths = _build_supervisor(tmp_path) client = TestClient(supervisor.app)