Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions homeassistant_api/asyncwebsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from homeassistant_api.models.config_entries import FlowResult
from homeassistant_api.models.entity_registry import EntityRegistryEntry
from homeassistant_api.models.entity_registry import EntityRegistryEntryExtended
from homeassistant_api.models.entity_registry import EntityRegistryUpdateParams
from homeassistant_api.models.entity_registry import EntityRegistryUpdateResult
from homeassistant_api.models.states import Context
from homeassistant_api.models.websocket import AuthInvalid
Expand Down Expand Up @@ -713,8 +714,7 @@ async def get_entity_registry_entry(

async def update_entity_registry_entry(
self,
entity_id: str,
**kwargs: Any,
parameters: EntityRegistryUpdateParams,
) -> EntityRegistryUpdateResult:
"""
Update an entity registry entry.
Expand All @@ -724,8 +724,7 @@ async def update_entity_registry_entry(
result = await self.recv_result_dict(
await self.send(
"config/entity_registry/update",
entity_id=entity_id,
**kwargs,
**parameters,
),
)
return EntityRegistryUpdateResult.from_json(result)
Expand Down
23 changes: 22 additions & 1 deletion homeassistant_api/models/entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from enum import Enum
from typing import Any
from typing import TypedDict

from pydantic import Field
from typing_extensions import NotRequired

from .base import BaseModel
from .base import DatetimeIsoField
Expand Down Expand Up @@ -59,7 +61,7 @@ class EntityRegistryEntry(BaseModel):


class EntityRegistryEntryExtended(EntityRegistryEntry):
"""Extended entity registry entry as returned by ``config/entity_registry/get`` and ``update``."""
"""Extended entity registry entry as returned by ``config/entity_registry/get``."""

aliases: list[str] = Field(default_factory=list)
capabilities: dict[str, Any] | None = None
Expand All @@ -74,3 +76,22 @@ class EntityRegistryUpdateResult(BaseModel):
entity_entry: EntityRegistryEntryExtended
reload_delay: int | None = None
require_restart: bool = False


class EntityRegistryUpdateParams(TypedDict):
"""Parameters used in ``config/entity_registry/update``."""

aliases: NotRequired[list[str]]
area_id: NotRequired[str | None]
categories: NotRequired[dict[str, str]]
device_class: NotRequired[str | None]
disabled_by: NotRequired[EntityHiddenBy | None]
entity_id: str
hidden_by: NotRequired[EntityHiddenBy | None]
icon: NotRequired[str | None]
labels: NotRequired[list[str]]
name: NotRequired[str | None]
new_entity_id: NotRequired[str]
# options and options_domain are inclusive, meaning only both or none of them have to be defined
options_domain: NotRequired[str]
options: NotRequired[dict[str, Any]]
7 changes: 3 additions & 4 deletions homeassistant_api/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from homeassistant_api.models.config_entries import FlowResult
from homeassistant_api.models.entity_registry import EntityRegistryEntry
from homeassistant_api.models.entity_registry import EntityRegistryEntryExtended
from homeassistant_api.models.entity_registry import EntityRegistryUpdateParams
from homeassistant_api.models.entity_registry import EntityRegistryUpdateResult
from homeassistant_api.models.states import Context
from homeassistant_api.models.websocket import AuthInvalid
Expand Down Expand Up @@ -686,8 +687,7 @@ def get_entity_registry_entry(self, entity_id: str) -> EntityRegistryEntryExtend

def update_entity_registry_entry(
self,
entity_id: str,
**kwargs: Any,
parameters: EntityRegistryUpdateParams,
) -> EntityRegistryUpdateResult:
"""
Update an entity registry entry.
Expand All @@ -697,8 +697,7 @@ def update_entity_registry_entry(
result = self.recv_result_dict(
self.send(
"config/entity_registry/update",
entity_id=entity_id,
**kwargs,
**parameters,
),
)
return EntityRegistryUpdateResult.from_json(result)
Expand Down
26 changes: 19 additions & 7 deletions tests/test_entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from homeassistant_api import WebsocketClient
from homeassistant_api.models.entity_registry import EntityRegistryEntry
from homeassistant_api.models.entity_registry import EntityRegistryEntryExtended
from homeassistant_api.models.entity_registry import EntityRegistryUpdateParams
from homeassistant_api.models.entity_registry import EntityRegistryUpdateResult

# A stable, always-present entity for read/update tests
Expand All @@ -29,14 +30,21 @@ def test_get_entity_registry_entry(websocket_client: WebsocketClient) -> None:

def test_update_entity_registry_entry(websocket_client: WebsocketClient) -> None:
result = websocket_client.update_entity_registry_entry(
_TEST_ENTITY_ID,
name="Test Name",
EntityRegistryUpdateParams(
entity_id=_TEST_ENTITY_ID,
name="Test Name",
),
)
assert isinstance(result, EntityRegistryUpdateResult)
assert result.entity_entry.entity_id == _TEST_ENTITY_ID
assert result.entity_entry.name == "Test Name"
# Restore original state
websocket_client.update_entity_registry_entry(_TEST_ENTITY_ID, name=None)
websocket_client.update_entity_registry_entry(
EntityRegistryUpdateParams(
entity_id=_TEST_ENTITY_ID,
name=None,
),
)


def test_remove_entity_registry_entry(websocket_client: WebsocketClient) -> None:
Expand Down Expand Up @@ -72,16 +80,20 @@ async def test_async_update_entity_registry_entry(
async_websocket_client: AsyncWebsocketClient,
) -> None:
result = await async_websocket_client.update_entity_registry_entry(
_TEST_ENTITY_ID,
name="Async Test Name",
EntityRegistryUpdateParams(
entity_id=_TEST_ENTITY_ID,
name="Async Test Name",
),
)
assert isinstance(result, EntityRegistryUpdateResult)
assert result.entity_entry.entity_id == _TEST_ENTITY_ID
assert result.entity_entry.name == "Async Test Name"
# Restore original state
await async_websocket_client.update_entity_registry_entry(
_TEST_ENTITY_ID,
name=None,
EntityRegistryUpdateParams(
entity_id=_TEST_ENTITY_ID,
name=None,
),
)


Expand Down