sandbox_v2: protobuf wire + typed handlers (transport T2)

Atomic switch of the control channel from JSON dicts to typed protobuf
messages, completing transport T2 on top of T1's Transport/Codec seam.

- Codec owns the registry: each side builds a type -> (request_cls,
  result_cls) map from its own _proto mirror and constructs
  ProtobufCodec(registry). The concurrency-critical Channel core stays
  fully codec-agnostic; response frames now carry `type` so the stateless
  codec resolves the result class on both encode and decode.

- Proto refinements (locked 2026-06-03): EntityDescription wraps EntityInfo
  (identity: Description + DeviceInfo) and InitialState (state +
  capabilities + attributes); ServiceResponse is a typed envelope inside
  CallServiceResult (proto3 optional, no has_response bool); StateChanged
  is flattened and carries optional context_id; FireEvent carries optional
  context_id. Dynamic fields cross as Struct/ListValue.

- Context security model: the sandbox only ever sends a context_id string;
  parent_id / user_id never cross the wire. Main resolves the id to its own
  authoritative Context via SandboxBridge._resolve_context — reusing a
  cached Context or minting a fresh one attributed to the sandbox system
  user with no parent_id — for state_changed, fire_event and call_service.

- Generated _pb2 mirrors checked into both no-cross-import trees; regen via
  sandbox_v2/proto/generate.sh (isolated venv so the protobuf==6.32.0 pin is
  never bumped). Drift guard wired as a manual-stage prek hook that degrades
  gracefully when uv is absent.

- Default codec is protobuf (manager + runtime channel construction);
  JsonCodec is retained registry-free as the test wire for the channel-core
  tests. protobuf added to the client pyproject + the HA manifest
  requirements; grpcio-tools stays out of the project venv by design.

- ~20 handlers converted to typed messages across bridge.py, entry_runner,
  flow_runner, entity_bridge, service/event mirrors, sandbox_bridge and the
  schema bridge; ~69 test call/push sites translated with no assertion
  loosening (semantics shifts forced by proto presence are commented).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
Paulus Schoutsen
2026-06-03 08:25:58 -04:00
parent 43eb0ca426
commit 360e454330
64 changed files with 3762 additions and 1046 deletions
+14
View File
@@ -64,6 +64,17 @@ repos:
files: ^(homeassistant|tests|script)/.+\.py$
- repo: local
hooks:
# Drift guard for the checked-in sandbox_v2 protobuf gencode. Manual
# stage only (grpcio-tools is not a project dep, so it bootstraps a
# throwaway venv and degrades gracefully when uv is absent): run with
# `prek run --hook-stage manual sandbox-v2-proto-drift` or in a CI lane.
- id: sandbox-v2-proto-drift
name: sandbox_v2 protobuf gencode drift guard
entry: sandbox_v2/proto/check_drift.sh
language: script
pass_filenames: false
stages: [manual]
files: ^sandbox_v2/proto/sandbox_v2\.proto$
# Run mypy through our wrapper script in order to get the possible
# pyenv and/or virtualenv activated; it may not have been e.g. if
# committing from a GUI tool that was not launched from an activated
@@ -75,6 +86,9 @@ repos:
require_serial: true
types_or: [python, pyi]
files: ^(homeassistant|pylint)/.+\.(py|pyi)$
# Checked-in protobuf gencode (sandbox_v2): the .py + .pyi pair trips
# mypy's duplicate-module check, and it is machine-generated anyway.
exclude: _pb2\.(py|pyi)$
- id: pylint
name: pylint
entry: script/run-in-env.sh pylint --ignore-missing-annotations=y
@@ -23,6 +23,7 @@ from homeassistant.core import Event, HomeAssistant
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import ConfigType
from ._proto import sandbox_v2_pb2 as pb
from .auth import async_issue_sandbox_access_token
from .bridge import SandboxBridge, async_create_bridge
from .channel import Channel
@@ -59,18 +60,17 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def _issue_token(group: str) -> str:
return await async_issue_sandbox_access_token(hass, group)
async def _on_shutdown_reply(group: str, reply: dict[str, Any]) -> None:
async def _on_shutdown_reply(group: str, reply: Any) -> None:
"""Persist the sandbox's restore-state snapshot (Phase 9).
The runtime ships its ``RestoreEntity`` state in the shutdown
reply rather than via the sandbox store bridge (the reader task
is busy dispatching the shutdown handler — a re-entrant store_save
would deadlock). We route the payload through the bridge's
store server so it lands at the same path the next run's
warm-load reads from.
reply (a ``ShutdownResult``) rather than via the sandbox store
bridge (the reader task is busy dispatching the shutdown handler —
a re-entrant store_save would deadlock). We route the payload
through the bridge's store server so it lands at the same path the
next run's warm-load reads from.
"""
restore_state = reply.get("restore_state")
if not isinstance(restore_state, dict):
if not reply.HasField("restore_state"):
return
bridge = data.bridges.get(group)
if bridge is None:
@@ -82,7 +82,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return
try:
await bridge._handle_store_save( # noqa: SLF001 — internal write path
{"key": "core.restore_state", "data": restore_state}
pb.StoreSave(key="core.restore_state", data=reply.restore_state)
)
except Exception:
_LOGGER.exception(
File diff suppressed because one or more lines are too long
@@ -0,0 +1,427 @@
from google.protobuf import struct_pb2 as _struct_pb2
from google.protobuf.internal import containers as _containers
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class Frame(_message.Message):
__slots__ = ("id", "type", "request", "response")
ID_FIELD_NUMBER: _ClassVar[int]
TYPE_FIELD_NUMBER: _ClassVar[int]
REQUEST_FIELD_NUMBER: _ClassVar[int]
RESPONSE_FIELD_NUMBER: _ClassVar[int]
id: int
type: str
request: bytes
response: Response
def __init__(self, id: _Optional[int] = ..., type: _Optional[str] = ..., request: _Optional[bytes] = ..., response: _Optional[_Union[Response, _Mapping]] = ...) -> None: ...
class Response(_message.Message):
__slots__ = ("ok", "result", "error")
OK_FIELD_NUMBER: _ClassVar[int]
RESULT_FIELD_NUMBER: _ClassVar[int]
ERROR_FIELD_NUMBER: _ClassVar[int]
ok: bool
result: bytes
error: Error
def __init__(self, ok: bool = ..., result: _Optional[bytes] = ..., error: _Optional[_Union[Error, _Mapping]] = ...) -> None: ...
class Error(_message.Message):
__slots__ = ("message", "type", "invalid", "multiple")
MESSAGE_FIELD_NUMBER: _ClassVar[int]
TYPE_FIELD_NUMBER: _ClassVar[int]
INVALID_FIELD_NUMBER: _ClassVar[int]
MULTIPLE_FIELD_NUMBER: _ClassVar[int]
message: str
type: str
invalid: _containers.RepeatedCompositeFieldContainer[InvalidError]
multiple: bool
def __init__(self, message: _Optional[str] = ..., type: _Optional[str] = ..., invalid: _Optional[_Iterable[_Union[InvalidError, _Mapping]]] = ..., multiple: bool = ...) -> None: ...
class InvalidError(_message.Message):
__slots__ = ("message", "path")
MESSAGE_FIELD_NUMBER: _ClassVar[int]
PATH_FIELD_NUMBER: _ClassVar[int]
message: str
path: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, message: _Optional[str] = ..., path: _Optional[_Iterable[str]] = ...) -> None: ...
class DevicePair(_message.Message):
__slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int]
VALUE_FIELD_NUMBER: _ClassVar[int]
key: str
value: str
def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ...
class DeviceInfo(_message.Message):
__slots__ = ("identifiers", "connections", "via_device", "entry_type", "name", "manufacturer", "model", "model_id", "sw_version", "hw_version", "serial_number", "suggested_area", "configuration_url", "default_name", "default_manufacturer", "default_model", "translation_key")
IDENTIFIERS_FIELD_NUMBER: _ClassVar[int]
CONNECTIONS_FIELD_NUMBER: _ClassVar[int]
VIA_DEVICE_FIELD_NUMBER: _ClassVar[int]
ENTRY_TYPE_FIELD_NUMBER: _ClassVar[int]
NAME_FIELD_NUMBER: _ClassVar[int]
MANUFACTURER_FIELD_NUMBER: _ClassVar[int]
MODEL_FIELD_NUMBER: _ClassVar[int]
MODEL_ID_FIELD_NUMBER: _ClassVar[int]
SW_VERSION_FIELD_NUMBER: _ClassVar[int]
HW_VERSION_FIELD_NUMBER: _ClassVar[int]
SERIAL_NUMBER_FIELD_NUMBER: _ClassVar[int]
SUGGESTED_AREA_FIELD_NUMBER: _ClassVar[int]
CONFIGURATION_URL_FIELD_NUMBER: _ClassVar[int]
DEFAULT_NAME_FIELD_NUMBER: _ClassVar[int]
DEFAULT_MANUFACTURER_FIELD_NUMBER: _ClassVar[int]
DEFAULT_MODEL_FIELD_NUMBER: _ClassVar[int]
TRANSLATION_KEY_FIELD_NUMBER: _ClassVar[int]
identifiers: _containers.RepeatedCompositeFieldContainer[DevicePair]
connections: _containers.RepeatedCompositeFieldContainer[DevicePair]
via_device: DevicePair
entry_type: str
name: str
manufacturer: str
model: str
model_id: str
sw_version: str
hw_version: str
serial_number: str
suggested_area: str
configuration_url: str
default_name: str
default_manufacturer: str
default_model: str
translation_key: str
def __init__(self, identifiers: _Optional[_Iterable[_Union[DevicePair, _Mapping]]] = ..., connections: _Optional[_Iterable[_Union[DevicePair, _Mapping]]] = ..., via_device: _Optional[_Union[DevicePair, _Mapping]] = ..., entry_type: _Optional[str] = ..., name: _Optional[str] = ..., manufacturer: _Optional[str] = ..., model: _Optional[str] = ..., model_id: _Optional[str] = ..., sw_version: _Optional[str] = ..., hw_version: _Optional[str] = ..., serial_number: _Optional[str] = ..., suggested_area: _Optional[str] = ..., configuration_url: _Optional[str] = ..., default_name: _Optional[str] = ..., default_manufacturer: _Optional[str] = ..., default_model: _Optional[str] = ..., translation_key: _Optional[str] = ...) -> None: ...
class EntrySetup(_message.Message):
__slots__ = ("entry_id", "domain", "title", "data", "options", "source", "unique_id", "version", "minor_version")
ENTRY_ID_FIELD_NUMBER: _ClassVar[int]
DOMAIN_FIELD_NUMBER: _ClassVar[int]
TITLE_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
OPTIONS_FIELD_NUMBER: _ClassVar[int]
SOURCE_FIELD_NUMBER: _ClassVar[int]
UNIQUE_ID_FIELD_NUMBER: _ClassVar[int]
VERSION_FIELD_NUMBER: _ClassVar[int]
MINOR_VERSION_FIELD_NUMBER: _ClassVar[int]
entry_id: str
domain: str
title: str
data: _struct_pb2.Struct
options: _struct_pb2.Struct
source: str
unique_id: str
version: int
minor_version: int
def __init__(self, entry_id: _Optional[str] = ..., domain: _Optional[str] = ..., title: _Optional[str] = ..., data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., options: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., source: _Optional[str] = ..., unique_id: _Optional[str] = ..., version: _Optional[int] = ..., minor_version: _Optional[int] = ...) -> None: ...
class EntrySetupResult(_message.Message):
__slots__ = ("ok", "reason")
OK_FIELD_NUMBER: _ClassVar[int]
REASON_FIELD_NUMBER: _ClassVar[int]
ok: bool
reason: str
def __init__(self, ok: bool = ..., reason: _Optional[str] = ...) -> None: ...
class EntryUnload(_message.Message):
__slots__ = ("entry_id",)
ENTRY_ID_FIELD_NUMBER: _ClassVar[int]
entry_id: str
def __init__(self, entry_id: _Optional[str] = ...) -> None: ...
class EntryUnloadResult(_message.Message):
__slots__ = ("ok",)
OK_FIELD_NUMBER: _ClassVar[int]
ok: bool
def __init__(self, ok: bool = ...) -> None: ...
class CallService(_message.Message):
__slots__ = ("domain", "service", "target", "service_data", "context_id", "return_response")
DOMAIN_FIELD_NUMBER: _ClassVar[int]
SERVICE_FIELD_NUMBER: _ClassVar[int]
TARGET_FIELD_NUMBER: _ClassVar[int]
SERVICE_DATA_FIELD_NUMBER: _ClassVar[int]
CONTEXT_ID_FIELD_NUMBER: _ClassVar[int]
RETURN_RESPONSE_FIELD_NUMBER: _ClassVar[int]
domain: str
service: str
target: _struct_pb2.Struct
service_data: _struct_pb2.Struct
context_id: str
return_response: bool
def __init__(self, domain: _Optional[str] = ..., service: _Optional[str] = ..., target: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., service_data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., context_id: _Optional[str] = ..., return_response: bool = ...) -> None: ...
class ServiceResponse(_message.Message):
__slots__ = ("data",)
DATA_FIELD_NUMBER: _ClassVar[int]
data: _struct_pb2.Struct
def __init__(self, data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class CallServiceResult(_message.Message):
__slots__ = ("response",)
RESPONSE_FIELD_NUMBER: _ClassVar[int]
response: ServiceResponse
def __init__(self, response: _Optional[_Union[ServiceResponse, _Mapping]] = ...) -> None: ...
class Shutdown(_message.Message):
__slots__ = ()
def __init__(self) -> None: ...
class ShutdownResult(_message.Message):
__slots__ = ("ok", "unloaded", "restore_state")
OK_FIELD_NUMBER: _ClassVar[int]
UNLOADED_FIELD_NUMBER: _ClassVar[int]
RESTORE_STATE_FIELD_NUMBER: _ClassVar[int]
ok: bool
unloaded: int
restore_state: _struct_pb2.Struct
def __init__(self, ok: bool = ..., unloaded: _Optional[int] = ..., restore_state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class Ping(_message.Message):
__slots__ = ()
def __init__(self) -> None: ...
class PingResult(_message.Message):
__slots__ = ("pong",)
PONG_FIELD_NUMBER: _ClassVar[int]
pong: str
def __init__(self, pong: _Optional[str] = ...) -> None: ...
class Ready(_message.Message):
__slots__ = ()
def __init__(self) -> None: ...
class FlowInit(_message.Message):
__slots__ = ("handler", "context", "data")
HANDLER_FIELD_NUMBER: _ClassVar[int]
CONTEXT_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
handler: str
context: _struct_pb2.Struct
data: _struct_pb2.Struct
def __init__(self, handler: _Optional[str] = ..., context: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class FlowStep(_message.Message):
__slots__ = ("flow_id", "user_input")
FLOW_ID_FIELD_NUMBER: _ClassVar[int]
USER_INPUT_FIELD_NUMBER: _ClassVar[int]
flow_id: str
user_input: _struct_pb2.Struct
def __init__(self, flow_id: _Optional[str] = ..., user_input: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class FlowAbort(_message.Message):
__slots__ = ("flow_id",)
FLOW_ID_FIELD_NUMBER: _ClassVar[int]
flow_id: str
def __init__(self, flow_id: _Optional[str] = ...) -> None: ...
class FlowAbortResult(_message.Message):
__slots__ = ()
def __init__(self) -> None: ...
class FlowResult(_message.Message):
__slots__ = ("type", "flow_id", "handler", "step_id", "reason", "title", "description", "last_step", "preview", "version", "minor_version", "data", "options", "errors", "description_placeholders", "context", "data_schema", "has_data_schema")
TYPE_FIELD_NUMBER: _ClassVar[int]
FLOW_ID_FIELD_NUMBER: _ClassVar[int]
HANDLER_FIELD_NUMBER: _ClassVar[int]
STEP_ID_FIELD_NUMBER: _ClassVar[int]
REASON_FIELD_NUMBER: _ClassVar[int]
TITLE_FIELD_NUMBER: _ClassVar[int]
DESCRIPTION_FIELD_NUMBER: _ClassVar[int]
LAST_STEP_FIELD_NUMBER: _ClassVar[int]
PREVIEW_FIELD_NUMBER: _ClassVar[int]
VERSION_FIELD_NUMBER: _ClassVar[int]
MINOR_VERSION_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
OPTIONS_FIELD_NUMBER: _ClassVar[int]
ERRORS_FIELD_NUMBER: _ClassVar[int]
DESCRIPTION_PLACEHOLDERS_FIELD_NUMBER: _ClassVar[int]
CONTEXT_FIELD_NUMBER: _ClassVar[int]
DATA_SCHEMA_FIELD_NUMBER: _ClassVar[int]
HAS_DATA_SCHEMA_FIELD_NUMBER: _ClassVar[int]
type: str
flow_id: str
handler: str
step_id: str
reason: str
title: str
description: str
last_step: bool
preview: str
version: int
minor_version: int
data: _struct_pb2.Struct
options: _struct_pb2.Struct
errors: _struct_pb2.Struct
description_placeholders: _struct_pb2.Struct
context: _struct_pb2.Struct
data_schema: _struct_pb2.ListValue
has_data_schema: bool
def __init__(self, type: _Optional[str] = ..., flow_id: _Optional[str] = ..., handler: _Optional[str] = ..., step_id: _Optional[str] = ..., reason: _Optional[str] = ..., title: _Optional[str] = ..., description: _Optional[str] = ..., last_step: bool = ..., preview: _Optional[str] = ..., version: _Optional[int] = ..., minor_version: _Optional[int] = ..., data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., options: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., errors: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., description_placeholders: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., context: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., data_schema: _Optional[_Union[_struct_pb2.ListValue, _Mapping]] = ..., has_data_schema: bool = ...) -> None: ...
class EntityInfo(_message.Message):
__slots__ = ("description", "device_info")
class Description(_message.Message):
__slots__ = ("name", "icon", "entity_category", "device_class", "supported_features", "translation_key")
NAME_FIELD_NUMBER: _ClassVar[int]
ICON_FIELD_NUMBER: _ClassVar[int]
ENTITY_CATEGORY_FIELD_NUMBER: _ClassVar[int]
DEVICE_CLASS_FIELD_NUMBER: _ClassVar[int]
SUPPORTED_FEATURES_FIELD_NUMBER: _ClassVar[int]
TRANSLATION_KEY_FIELD_NUMBER: _ClassVar[int]
name: str
icon: str
entity_category: str
device_class: str
supported_features: int
translation_key: str
def __init__(self, name: _Optional[str] = ..., icon: _Optional[str] = ..., entity_category: _Optional[str] = ..., device_class: _Optional[str] = ..., supported_features: _Optional[int] = ..., translation_key: _Optional[str] = ...) -> None: ...
DESCRIPTION_FIELD_NUMBER: _ClassVar[int]
DEVICE_INFO_FIELD_NUMBER: _ClassVar[int]
description: EntityInfo.Description
device_info: DeviceInfo
def __init__(self, description: _Optional[_Union[EntityInfo.Description, _Mapping]] = ..., device_info: _Optional[_Union[DeviceInfo, _Mapping]] = ...) -> None: ...
class InitialState(_message.Message):
__slots__ = ("state", "capabilities", "attributes")
STATE_FIELD_NUMBER: _ClassVar[int]
CAPABILITIES_FIELD_NUMBER: _ClassVar[int]
ATTRIBUTES_FIELD_NUMBER: _ClassVar[int]
state: str
capabilities: _struct_pb2.Struct
attributes: _struct_pb2.Struct
def __init__(self, state: _Optional[str] = ..., capabilities: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., attributes: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class EntityDescription(_message.Message):
__slots__ = ("entry_id", "domain", "sandbox_entity_id", "unique_id", "has_entity_name", "info", "initial")
ENTRY_ID_FIELD_NUMBER: _ClassVar[int]
DOMAIN_FIELD_NUMBER: _ClassVar[int]
SANDBOX_ENTITY_ID_FIELD_NUMBER: _ClassVar[int]
UNIQUE_ID_FIELD_NUMBER: _ClassVar[int]
HAS_ENTITY_NAME_FIELD_NUMBER: _ClassVar[int]
INFO_FIELD_NUMBER: _ClassVar[int]
INITIAL_FIELD_NUMBER: _ClassVar[int]
entry_id: str
domain: str
sandbox_entity_id: str
unique_id: str
has_entity_name: bool
info: EntityInfo
initial: InitialState
def __init__(self, entry_id: _Optional[str] = ..., domain: _Optional[str] = ..., sandbox_entity_id: _Optional[str] = ..., unique_id: _Optional[str] = ..., has_entity_name: bool = ..., info: _Optional[_Union[EntityInfo, _Mapping]] = ..., initial: _Optional[_Union[InitialState, _Mapping]] = ...) -> None: ...
class RegisterEntityResult(_message.Message):
__slots__ = ("entity_id",)
ENTITY_ID_FIELD_NUMBER: _ClassVar[int]
entity_id: str
def __init__(self, entity_id: _Optional[str] = ...) -> None: ...
class UnregisterEntity(_message.Message):
__slots__ = ("sandbox_entity_id",)
SANDBOX_ENTITY_ID_FIELD_NUMBER: _ClassVar[int]
sandbox_entity_id: str
def __init__(self, sandbox_entity_id: _Optional[str] = ...) -> None: ...
class UnregisterEntityResult(_message.Message):
__slots__ = ("ok",)
OK_FIELD_NUMBER: _ClassVar[int]
ok: bool
def __init__(self, ok: bool = ...) -> None: ...
class StateChanged(_message.Message):
__slots__ = ("sandbox_entity_id", "state", "attributes", "context_id")
SANDBOX_ENTITY_ID_FIELD_NUMBER: _ClassVar[int]
STATE_FIELD_NUMBER: _ClassVar[int]
ATTRIBUTES_FIELD_NUMBER: _ClassVar[int]
CONTEXT_ID_FIELD_NUMBER: _ClassVar[int]
sandbox_entity_id: str
state: str
attributes: _struct_pb2.Struct
context_id: str
def __init__(self, sandbox_entity_id: _Optional[str] = ..., state: _Optional[str] = ..., attributes: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., context_id: _Optional[str] = ...) -> None: ...
class RegisterService(_message.Message):
__slots__ = ("domain", "service", "supports_response", "schema")
DOMAIN_FIELD_NUMBER: _ClassVar[int]
SERVICE_FIELD_NUMBER: _ClassVar[int]
SUPPORTS_RESPONSE_FIELD_NUMBER: _ClassVar[int]
SCHEMA_FIELD_NUMBER: _ClassVar[int]
domain: str
service: str
supports_response: str
schema: _struct_pb2.ListValue
def __init__(self, domain: _Optional[str] = ..., service: _Optional[str] = ..., supports_response: _Optional[str] = ..., schema: _Optional[_Union[_struct_pb2.ListValue, _Mapping]] = ...) -> None: ...
class RegisterServiceResult(_message.Message):
__slots__ = ("ok", "installed")
OK_FIELD_NUMBER: _ClassVar[int]
INSTALLED_FIELD_NUMBER: _ClassVar[int]
ok: bool
installed: bool
def __init__(self, ok: bool = ..., installed: bool = ...) -> None: ...
class UnregisterService(_message.Message):
__slots__ = ("domain", "service")
DOMAIN_FIELD_NUMBER: _ClassVar[int]
SERVICE_FIELD_NUMBER: _ClassVar[int]
domain: str
service: str
def __init__(self, domain: _Optional[str] = ..., service: _Optional[str] = ...) -> None: ...
class UnregisterServiceResult(_message.Message):
__slots__ = ("ok", "removed")
OK_FIELD_NUMBER: _ClassVar[int]
REMOVED_FIELD_NUMBER: _ClassVar[int]
ok: bool
removed: bool
def __init__(self, ok: bool = ..., removed: bool = ...) -> None: ...
class FireEvent(_message.Message):
__slots__ = ("event_type", "event_data", "context_id")
EVENT_TYPE_FIELD_NUMBER: _ClassVar[int]
EVENT_DATA_FIELD_NUMBER: _ClassVar[int]
CONTEXT_ID_FIELD_NUMBER: _ClassVar[int]
event_type: str
event_data: _struct_pb2.Struct
context_id: str
def __init__(self, event_type: _Optional[str] = ..., event_data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., context_id: _Optional[str] = ...) -> None: ...
class StoreLoad(_message.Message):
__slots__ = ("key",)
KEY_FIELD_NUMBER: _ClassVar[int]
key: str
def __init__(self, key: _Optional[str] = ...) -> None: ...
class StoreLoadResult(_message.Message):
__slots__ = ("data",)
DATA_FIELD_NUMBER: _ClassVar[int]
data: _struct_pb2.Struct
def __init__(self, data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class StoreSave(_message.Message):
__slots__ = ("key", "data")
KEY_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
key: str
data: _struct_pb2.Struct
def __init__(self, key: _Optional[str] = ..., data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class StoreSaveResult(_message.Message):
__slots__ = ("ok",)
OK_FIELD_NUMBER: _ClassVar[int]
ok: bool
def __init__(self, ok: bool = ...) -> None: ...
class StoreRemove(_message.Message):
__slots__ = ("key",)
KEY_FIELD_NUMBER: _ClassVar[int]
key: str
def __init__(self, key: _Optional[str] = ...) -> None: ...
class StoreRemoveResult(_message.Message):
__slots__ = ("ok",)
OK_FIELD_NUMBER: _ClassVar[int]
ok: bool
def __init__(self, ok: bool = ...) -> None: ...
+204 -120
View File
@@ -43,7 +43,13 @@ from typing import Any
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, ServiceCall, SupportsResponse, callback
from homeassistant.core import (
Context,
HomeAssistant,
ServiceCall,
SupportsResponse,
callback,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, json as json_helper
from homeassistant.helpers.entity_component import DATA_INSTANCES, EntityComponent
@@ -53,8 +59,11 @@ from homeassistant.setup import async_setup_component
from homeassistant.util import json as json_util
from homeassistant.util.file import write_utf8_file_atomic
from ._proto import sandbox_v2_pb2 as pb
from .auth import async_get_or_create_sandbox_user
from .channel import Channel, ChannelClosedError, ChannelRemoteError
from .const import UNIQUE_ID_SEPARATOR
from .messages import dict_to_struct, listvalue_to_list, struct_to_dict
from .protocol import (
MSG_CALL_SERVICE,
MSG_FIRE_EVENT,
@@ -95,23 +104,42 @@ class SandboxEntityDescription:
device_id: str | None = None
@classmethod
def from_payload(cls, payload: Mapping[str, Any]) -> SandboxEntityDescription:
"""Build a description from the wire payload."""
def from_proto(cls, msg: pb.EntityDescription) -> SandboxEntityDescription:
"""Build a description from the typed ``EntityDescription`` message.
Flattens the nested ``EntityInfo`` / ``InitialState`` sub-messages back
into the flat shape the proxy entities consume.
"""
description = msg.info.description
initial = msg.initial
device_info = (
_deserialise_device_info(msg.info.device_info)
if msg.info.HasField("device_info")
else None
)
return cls(
entry_id=payload["entry_id"],
domain=payload["domain"],
sandbox_entity_id=payload["sandbox_entity_id"],
unique_id=payload.get("unique_id"),
name=payload.get("name"),
icon=payload.get("icon"),
has_entity_name=bool(payload.get("has_entity_name", False)),
entity_category=payload.get("entity_category"),
device_class=payload.get("device_class"),
supported_features=int(payload.get("supported_features") or 0),
capabilities=dict(payload.get("capabilities") or {}),
initial_state=payload.get("initial_state"),
initial_attributes=dict(payload.get("initial_attributes") or {}),
device_info=_deserialise_device_info(payload.get("device_info")),
entry_id=msg.entry_id,
domain=msg.domain,
sandbox_entity_id=msg.sandbox_entity_id,
unique_id=msg.unique_id if msg.HasField("unique_id") else None,
name=description.name if description.HasField("name") else None,
icon=description.icon if description.HasField("icon") else None,
has_entity_name=msg.has_entity_name,
entity_category=(
description.entity_category
if description.HasField("entity_category")
else None
),
device_class=(
description.device_class
if description.HasField("device_class")
else None
),
supported_features=description.supported_features,
capabilities=struct_to_dict(initial.capabilities),
initial_state=initial.state if initial.HasField("state") else None,
initial_attributes=struct_to_dict(initial.attributes),
device_info=device_info,
)
@@ -244,6 +272,13 @@ class SandboxBridge:
self._store_server = _SandboxStoreServer(hass, group)
# Context security: the sandbox only ever sends a context_id (a
# string). Main resolves it to its own authoritative Context, never
# honouring a sandbox-supplied parent_id / user_id. Resolved contexts
# are cached so a repeated id maps to one stable Context.
self._system_user_id: str | None = None
self._contexts: dict[str, Context] = {}
channel.register(MSG_REGISTER_ENTITY, self._handle_register_entity)
channel.register(MSG_UNREGISTER_ENTITY, self._handle_unregister_entity)
channel.register(MSG_STATE_CHANGED, self._handle_state_changed)
@@ -290,17 +325,17 @@ class SandboxBridge:
return_response: bool,
) -> Any:
"""Send one ``sandbox_v2/call_service`` RPC and translate errors."""
payload: dict[str, Any] = {
"domain": domain,
"service": service,
"target": target,
"service_data": service_data,
"return_response": return_response,
}
request = pb.CallService(
domain=domain,
service=service,
target=dict_to_struct(target),
service_data=dict_to_struct(service_data),
return_response=return_response,
)
if context_id is not None:
payload["context_id"] = context_id
request.context_id = context_id
try:
return await self.channel.call(MSG_CALL_SERVICE, payload)
return await self.channel.call(MSG_CALL_SERVICE, request)
except ChannelRemoteError as err:
raise _translate_remote_error(err) from err
except ChannelClosedError as err:
@@ -308,10 +343,35 @@ class SandboxBridge:
f"Sandbox {self.group!r} channel closed mid-call"
) from err
async def _async_system_user_id(self) -> str:
"""Return (and cache) the sandbox group's system-user id."""
if self._system_user_id is None:
user = await async_get_or_create_sandbox_user(self.hass, self.group)
self._system_user_id = user.id
return self._system_user_id
async def _resolve_context(self, context_id: str | None) -> Context:
"""Resolve a sandbox-supplied context_id to an authoritative Context.
The sandbox can never set ``parent_id`` / ``user_id`` on the wire —
main owns that. A context_id main has already resolved reuses the same
Context; an unseen id (or no id) mints a fresh Context attributed to
the sandbox's system user, with no ``parent_id``.
"""
user_id = await self._async_system_user_id()
if context_id is None:
return Context(user_id=user_id)
existing = self._contexts.get(context_id)
if existing is not None:
return existing
context = Context(id=context_id, user_id=user_id)
self._contexts[context_id] = context
return context
async def _handle_register_entity(
self, payload: Mapping[str, Any]
) -> dict[str, Any]:
description = SandboxEntityDescription.from_payload(payload)
self, msg: pb.EntityDescription
) -> pb.RegisterEntityResult:
description = SandboxEntityDescription.from_proto(msg)
entry = self.hass.config_entries.async_get_entry(description.entry_id)
if entry is None:
raise HomeAssistantError(
@@ -353,12 +413,12 @@ class SandboxBridge:
existing = self._entities.get(description.sandbox_entity_id)
if existing is not None:
existing.sandbox_update_description(description)
return {"entity_id": existing.entity_id or ""}
return pb.RegisterEntityResult(entity_id=existing.entity_id or "")
proxy = self._build_proxy(description)
platform = self._ensure_platform(entry, description.domain)
await platform.async_add_entities([proxy])
self._entities[description.sandbox_entity_id] = proxy
return {"entity_id": proxy.entity_id or ""}
return pb.RegisterEntityResult(entity_id=proxy.entity_id or "")
async def _ensure_domain_loaded(self, domain: str) -> None:
"""Make sure the domain's :class:`EntityComponent` is loaded on main."""
@@ -370,36 +430,39 @@ class SandboxBridge:
await async_setup_component(self.hass, domain, {})
async def _handle_unregister_entity(
self, payload: Mapping[str, Any]
) -> dict[str, Any]:
sandbox_entity_id = payload["sandbox_entity_id"]
self, msg: pb.UnregisterEntity
) -> pb.UnregisterEntityResult:
sandbox_entity_id = msg.sandbox_entity_id
proxy = self._entities.pop(sandbox_entity_id, None)
if proxy is None:
return {"ok": True}
return pb.UnregisterEntityResult(ok=True)
entity_id = getattr(proxy, "entity_id", None)
if not entity_id:
return {"ok": True}
return pb.UnregisterEntityResult(ok=True)
domain = entity_id.split(".", 1)[0]
component: EntityComponent[Any] | None = self.hass.data.get(
DATA_INSTANCES, {}
).get(domain)
if component is not None:
await component.async_remove_entity(entity_id)
return {"ok": True}
return pb.UnregisterEntityResult(ok=True)
async def _handle_state_changed(self, payload: Mapping[str, Any]) -> None:
sandbox_entity_id = payload["sandbox_entity_id"]
proxy = self._entities.get(sandbox_entity_id)
async def _handle_state_changed(self, msg: pb.StateChanged) -> None:
proxy = self._entities.get(msg.sandbox_entity_id)
if proxy is None:
return
new_state = payload.get("new_state") or {}
state_str = new_state.get("state")
attributes = dict(new_state.get("attributes") or {})
proxy.sandbox_apply_state(state_str, attributes)
state_str = msg.state if msg.HasField("state") else None
attributes = struct_to_dict(msg.attributes)
context = (
await self._resolve_context(msg.context_id)
if msg.HasField("context_id")
else None
)
proxy.sandbox_apply_state(state_str, attributes, context)
async def _handle_register_service(
self, payload: Mapping[str, Any]
) -> dict[str, Any]:
self, msg: pb.RegisterService
) -> pb.RegisterServiceResult:
"""Mirror a sandbox-registered service onto main's service registry.
The handler that gets installed forwards every call back over
@@ -414,9 +477,9 @@ class SandboxBridge:
already owns the slot) we skip the install — the existing
handler stays in charge.
"""
domain = str(payload["domain"]).lower()
service = str(payload["service"]).lower()
supports_response = _parse_supports_response(payload.get("supports_response"))
domain = msg.domain.lower()
service = msg.service.lower()
supports_response = _parse_supports_response(msg.supports_response)
if self.hass.services.has_service(domain, service):
_LOGGER.debug(
"SandboxBridge[%s]: %s.%s already on main, not replacing",
@@ -424,10 +487,10 @@ class SandboxBridge:
domain,
service,
)
return {"ok": True, "installed": False}
return pb.RegisterServiceResult(ok=True, installed=False)
forwarder = _build_service_forwarder(self, domain, service, supports_response)
schema = reconstruct_schema(payload.get("schema"))
schema = reconstruct_schema(listvalue_to_list(msg.schema))
self.hass.services.async_register(
domain,
service,
@@ -436,52 +499,56 @@ class SandboxBridge:
supports_response=supports_response,
)
self._mirrored_services.add((domain, service))
return {"ok": True, "installed": True}
return pb.RegisterServiceResult(ok=True, installed=True)
async def _handle_unregister_service(
self, payload: Mapping[str, Any]
) -> dict[str, Any]:
domain = str(payload["domain"]).lower()
service = str(payload["service"]).lower()
self, msg: pb.UnregisterService
) -> pb.UnregisterServiceResult:
domain = msg.domain.lower()
service = msg.service.lower()
key = (domain, service)
if key not in self._mirrored_services:
return {"ok": True, "removed": False}
return pb.UnregisterServiceResult(ok=True, removed=False)
self._mirrored_services.discard(key)
if self.hass.services.has_service(domain, service):
self.hass.services.async_remove(domain, service)
return {"ok": True, "removed": True}
return pb.UnregisterServiceResult(ok=True, removed=True)
async def _handle_store_load(
self, payload: Mapping[str, Any]
) -> dict[str, Any] | None:
async def _handle_store_load(self, msg: pb.StoreLoad) -> pb.StoreLoadResult:
"""Serve a sandbox-side ``Store.async_load`` (Phase 8)."""
return await self._store_server.async_load(_require_key(payload))
data = await self._store_server.async_load(_validate_key(msg.key))
result = pb.StoreLoadResult()
if data is not None:
result.data.update(data)
return result
async def _handle_store_save(self, payload: Mapping[str, Any]) -> dict[str, Any]:
async def _handle_store_save(self, msg: pb.StoreSave) -> pb.StoreSaveResult:
"""Persist a sandbox-side ``Store.async_save`` flush (Phase 8)."""
data = payload.get("data")
if not isinstance(data, dict):
raise HomeAssistantError("store_save: missing 'data' dict")
await self._store_server.async_save(_require_key(payload), data)
return {"ok": True}
await self._store_server.async_save(
_validate_key(msg.key), struct_to_dict(msg.data)
)
return pb.StoreSaveResult(ok=True)
async def _handle_store_remove(self, payload: Mapping[str, Any]) -> dict[str, Any]:
async def _handle_store_remove(self, msg: pb.StoreRemove) -> pb.StoreRemoveResult:
"""Drop the on-disk file for a sandbox-side ``Store.async_remove``."""
await self._store_server.async_remove(_require_key(payload))
return {"ok": True}
await self._store_server.async_remove(_validate_key(msg.key))
return pb.StoreRemoveResult(ok=True)
async def _handle_fire_event(self, payload: Mapping[str, Any]) -> None:
async def _handle_fire_event(self, msg: pb.FireEvent) -> None:
"""Re-fire a sandbox-side event on main's bus.
The sandbox tags every push with ``event_type`` + ``event_data``;
the context is reconstructed minimally so listeners on main see a
consistent ``Context`` shape (the sandbox's own context id is
forwarded but not honoured by main's user resolution — that's
intentional for v2).
The sandbox tags every push with ``event_type`` + ``event_data`` and,
optionally, a ``context_id``. Main resolves that id to an
authoritative Context attributed to the sandbox's system user — the
sandbox can never inject a ``parent_id`` / ``user_id``.
"""
event_type = str(payload["event_type"])
event_data = payload.get("event_data") or {}
self.hass.bus.async_fire(event_type, dict(event_data))
event_data = struct_to_dict(msg.event_data)
context = (
await self._resolve_context(msg.context_id)
if msg.HasField("context_id")
else None
)
self.hass.bus.async_fire(msg.event_type, event_data, context=context)
def _ensure_platform(self, entry: ConfigEntry, domain: str) -> EntityPlatform:
key = (entry.entry_id, domain)
@@ -542,8 +609,8 @@ class SandboxBridge:
_STORE_KEY_FORBIDDEN = ("/", "\\", "\x00")
def _require_key(payload: Mapping[str, Any]) -> str:
"""Extract + validate a ``key`` field from a store payload.
def _validate_key(key: str) -> str:
"""Validate a store ``key`` from the wire.
Defends the host filesystem from a compromised sandbox: a key must
be a non-empty string with no path separators, no null bytes, and
@@ -551,8 +618,7 @@ def _require_key(payload: Mapping[str, Any]) -> str:
:class:`HomeAssistantError`, which the channel framework turns into
a remote-error frame for the sandbox.
"""
key = payload.get("key")
if not isinstance(key, str) or not key:
if not key:
raise HomeAssistantError("store request: missing 'key'")
if any(ch in key for ch in _STORE_KEY_FORBIDDEN):
raise HomeAssistantError(f"store request: invalid key {key!r}")
@@ -628,33 +694,50 @@ class _SandboxStoreServer:
return
def _deserialise_device_info(value: Any) -> dict[str, Any] | None:
"""Rebuild a ``DeviceInfo`` TypedDict from the wire payload.
_DEVICE_INFO_STR_FIELDS = (
"name",
"manufacturer",
"model",
"model_id",
"sw_version",
"hw_version",
"serial_number",
"suggested_area",
"configuration_url",
"default_name",
"default_manufacturer",
"default_model",
"translation_key",
)
The sandbox-side serialiser flattens sets and tuples to lists of
two-element lists; this reverses that so
:func:`device_registry.async_get_or_create` sees the shapes its
validators expect. ``entry_type`` is rebuilt as a
:class:`DeviceEntryType` enum value.
def _deserialise_device_info(info: pb.DeviceInfo) -> dict[str, Any] | None:
"""Rebuild a ``DeviceInfo`` TypedDict from the typed proto.
``identifiers`` / ``connections`` come back as sets of tuples and
``via_device`` as a tuple — the shapes
:func:`device_registry.async_get_or_create` validates. ``entry_type`` is
rebuilt as a :class:`DeviceEntryType` enum value.
"""
if not value or not isinstance(value, Mapping):
return None
out: dict[str, Any] = {}
for key, raw in value.items():
if raw is None:
out[key] = None
elif key in ("identifiers", "connections") and isinstance(raw, list):
out[key] = {tuple(item) for item in raw if isinstance(item, list)}
elif key == "via_device" and isinstance(raw, list):
out[key] = tuple(raw)
elif key == "entry_type" and isinstance(raw, str):
try:
out[key] = dr.DeviceEntryType(raw)
except ValueError:
_LOGGER.debug("register_entity: unknown entry_type %r — dropping", raw)
else:
out[key] = raw
return out
if info.identifiers:
out["identifiers"] = {(pair.key, pair.value) for pair in info.identifiers}
if info.connections:
out["connections"] = {(pair.key, pair.value) for pair in info.connections}
if info.HasField("via_device"):
out["via_device"] = (info.via_device.key, info.via_device.value)
if info.entry_type:
try:
out["entry_type"] = dr.DeviceEntryType(info.entry_type)
except ValueError:
_LOGGER.debug(
"register_entity: unknown entry_type %r — dropping", info.entry_type
)
for field_name in _DEVICE_INFO_STR_FIELDS:
value = getattr(info, field_name)
if value:
out[field_name] = value
return out or None
def _parse_supports_response(value: Any) -> SupportsResponse:
@@ -685,16 +768,17 @@ def _build_service_forwarder(
"""
async def _forward(call: ServiceCall) -> Any:
payload: dict[str, Any] = {
"domain": domain,
"service": service,
"service_data": dict(call.data),
"target": _target_from_call(call),
"return_response": call.return_response,
"context_id": call.context.id if call.context is not None else None,
}
request = pb.CallService(
domain=domain,
service=service,
service_data=dict_to_struct(dict(call.data)),
target=dict_to_struct(_target_from_call(call)),
return_response=call.return_response,
)
if call.context is not None:
request.context_id = call.context.id
try:
response = await bridge.channel.call(MSG_CALL_SERVICE, payload)
response = await bridge.channel.call(MSG_CALL_SERVICE, request)
except ChannelRemoteError as err:
raise _translate_remote_error(err) from err
except ChannelClosedError as err:
@@ -703,9 +787,9 @@ def _build_service_forwarder(
) from err
if supports_response is SupportsResponse.NONE:
return None
if isinstance(response, Mapping):
return response.get("response", response)
return response
if response.HasField("response"):
return struct_to_dict(response.response.data)
return None
return _forward
+14 -4
View File
@@ -128,9 +128,15 @@ class Frame:
return cls(FrameKind.PUSH, id=0, type=msg_type, payload=payload)
@classmethod
def ok_response(cls, call_id: int, result: Any) -> Frame:
"""Build a success response frame."""
return cls(FrameKind.RESPONSE, id=call_id, ok=True, result=result)
def ok_response(cls, call_id: int, result: Any, msg_type: str = "") -> Frame:
"""Build a success response frame.
``msg_type`` is carried so a stateless codec (the protobuf one) can
look up the result message class on encode + decode.
"""
return cls(
FrameKind.RESPONSE, id=call_id, type=msg_type, ok=True, result=result
)
@classmethod
def error_response(
@@ -139,11 +145,13 @@ class Frame:
error: str,
error_type: str | None,
error_data: dict[str, Any] | None = None,
msg_type: str = "",
) -> Frame:
"""Build a failure response frame."""
return cls(
FrameKind.RESPONSE,
id=call_id,
type=msg_type,
ok=False,
error=error,
error_type=error_type,
@@ -514,6 +522,7 @@ class Channel:
frame.id,
f"no handler for {frame.type!r}",
"ChannelUnknownType",
msg_type=frame.type,
)
)
)
@@ -566,6 +575,7 @@ class Channel:
str(err) or err.__class__.__name__,
err.__class__.__name__,
error_data_for(err),
msg_type=msg_type,
)
with contextlib.suppress(Exception):
await self._write(frame)
@@ -573,7 +583,7 @@ class Channel:
if self._closed:
return
with contextlib.suppress(Exception):
await self._write(Frame.ok_response(call_id, result))
await self._write(Frame.ok_response(call_id, result, msg_type))
__all__ = [
@@ -0,0 +1,134 @@
"""Protobuf :class:`~.channel.Codec` — the production wire.
Serialises a :class:`~.channel.Frame` to the protobuf ``Frame`` envelope and
back. The envelope carries ``type`` on responses too, so this stateless codec
can look up the result message class from ``frame.type`` on both encode and
decode — the dispatch core never has to know about proto types (the registry
lives here, not on :meth:`Channel.register`).
Mirrored verbatim across the no-cross-import boundary (the same file lives at
``hass_client.codec_protobuf``); the relative imports resolve to each side's
own :mod:`messages` + ``_proto`` gencode.
"""
from typing import Any
from google.protobuf.message import Message
from ._proto import sandbox_v2_pb2 as pb
from .channel import Frame, FrameKind
from .messages import REGISTRY
Registry = dict[str, tuple[type[Message], type[Message] | None]]
class ProtobufCodec:
"""Encode/decode :class:`Frame` objects as protobuf ``Frame`` envelopes."""
def __init__(self, registry: Registry | None = None) -> None:
"""Build the codec over a ``type → (request_cls, result_cls)`` map."""
self._registry = registry if registry is not None else REGISTRY
def _classes(
self, msg_type: str
) -> tuple[type[Message] | None, type[Message] | None]:
return self._registry.get(msg_type, (None, None))
def encode(self, frame: Frame) -> bytes:
"""Serialise a frame to the protobuf ``Frame`` envelope bytes."""
envelope = pb.Frame(id=frame.id, type=frame.type)
if frame.kind is FrameKind.RESPONSE:
response = envelope.response
response.ok = frame.ok
if frame.ok:
_, result_cls = self._classes(frame.type)
response.result = _serialize_body(frame.result, result_cls)
else:
_fill_error(response.error, frame)
else:
request_cls, _ = self._classes(frame.type)
envelope.request = _serialize_body(frame.payload, request_cls)
return envelope.SerializeToString()
def decode(self, data: bytes) -> Frame:
"""Rebuild a frame from protobuf ``Frame`` envelope bytes."""
envelope = pb.Frame.FromString(data)
msg_type = envelope.type
body = envelope.WhichOneof("body")
if body == "response":
response = envelope.response
if response.ok:
_, result_cls = self._classes(msg_type)
result = _parse_body(response.result, result_cls)
return Frame.ok_response(envelope.id, result, msg_type)
error, error_type, error_data = _read_error(response.error)
return Frame.error_response(
envelope.id, error, error_type, error_data, msg_type
)
request_cls, _ = self._classes(msg_type)
payload = _parse_body(envelope.request, request_cls)
if envelope.id == 0:
return Frame.push(msg_type, payload)
return Frame.call(envelope.id, msg_type, payload)
def _serialize_body(body: Any, cls: type[Message] | None) -> bytes:
"""Serialise a proto-message body; ``None`` becomes an empty message."""
if body is None:
return cls().SerializeToString() if cls is not None else b""
if isinstance(body, Message):
return body.SerializeToString()
raise TypeError(
f"ProtobufCodec expected a proto message body, got {type(body).__name__}"
)
def _parse_body(raw: bytes, cls: type[Message] | None) -> Any:
"""Deserialise a body into ``cls``; an unregistered type decodes to None."""
if cls is None:
return None
return cls.FromString(raw)
def _fill_error(error: pb.Error, frame: Frame) -> None:
"""Populate the proto ``Error`` from a failure frame.
Carries fidelity #7's structured voluptuous data: the ``multiple`` flag
distinguishes a ``MultipleInvalid`` from a single ``Invalid`` so the peer
rebuilds the right exception.
"""
error.message = frame.error or ""
error.type = frame.error_type or ""
data = frame.error_data
if not data:
return
if data.get("kind") == "multiple":
error.multiple = True
for child in data.get("errors", []):
error.invalid.add(message=child.get("msg", ""), path=child.get("path", []))
elif data.get("kind") == "invalid":
error.invalid.add(message=data.get("msg", ""), path=data.get("path", []))
def _read_error(error: pb.Error) -> tuple[str, str | None, dict[str, Any] | None]:
"""Rebuild ``(message, type, error_data)`` from the proto ``Error``."""
error_data: dict[str, Any] | None = None
if error.multiple:
error_data = {
"kind": "multiple",
"errors": [
{"kind": "invalid", "msg": item.message, "path": list(item.path)}
for item in error.invalid
],
}
elif len(error.invalid) == 1:
item = error.invalid[0]
error_data = {
"kind": "invalid",
"msg": item.message,
"path": list(item.path),
}
return error.message, (error.type or None), error_data
__all__ = ["ProtobufCodec"]
@@ -17,6 +17,7 @@ from enum import IntFlag
from typing import TYPE_CHECKING, Any, cast
from homeassistant.const import EntityCategory
from homeassistant.core import Context
from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.entity import Entity
@@ -117,13 +118,24 @@ class SandboxProxyEntity(Entity):
self.async_write_ha_state()
def sandbox_apply_state(
self, state: str | None, attributes: dict[str, Any]
self,
state: str | None,
attributes: dict[str, Any],
context: Context | None = None,
) -> None:
"""Update the cache from a sandbox push, and notify HA."""
"""Update the cache from a sandbox push, and notify HA.
``context`` is the main-side authoritative Context the bridge resolved
from the sandbox's ``context_id`` (attributed to the sandbox system
user, never carrying a sandbox-supplied parent_id / user_id). When
absent the entity writes with its own context as before.
"""
self._state_cache = dict(attributes)
if state is not None:
self._state_cache["state"] = state
if self.hass is not None:
if context is not None:
self.async_set_context(context)
self.async_write_ha_state()
def sandbox_set_available(self, available: bool) -> None:
@@ -3,6 +3,7 @@
from typing import Any
from homeassistant.components.button import ButtonEntity
from homeassistant.core import Context
from . import SandboxProxyEntity
@@ -12,7 +13,10 @@ class SandboxButtonEntity(SandboxProxyEntity, ButtonEntity):
"""Proxy for a ``button`` entity in a sandbox."""
def sandbox_apply_state(
self, state: str | None, attributes: dict[str, Any]
self,
state: str | None,
attributes: dict[str, Any],
context: Context | None = None,
) -> None:
"""Forward sandbox state into ButtonEntity's last-pressed field.
@@ -24,7 +28,7 @@ class SandboxButtonEntity(SandboxProxyEntity, ButtonEntity):
if state is not None:
# pylint: disable-next=attribute-defined-outside-init
self._ButtonEntity__last_pressed_isoformat = state
super().sandbox_apply_state(state, attributes)
super().sandbox_apply_state(state, attributes, context)
async def async_press(self) -> None:
"""Forward press as a ``button.press`` service call."""
@@ -3,6 +3,7 @@
from typing import Any
from homeassistant.components.event import ATTR_EVENT_TYPE, EventEntity
from homeassistant.core import Context
from homeassistant.util import dt as dt_util
from . import SandboxProxyEntity
@@ -24,7 +25,10 @@ class SandboxEventEntity(SandboxProxyEntity, EventEntity):
return list(self.description.capabilities.get("event_types") or [])
def sandbox_apply_state(
self, state: str | None, attributes: dict[str, Any]
self,
state: str | None,
attributes: dict[str, Any],
context: Context | None = None,
) -> None:
"""Replay the sandbox-side event into the EventEntity fields."""
# pylint: disable=attribute-defined-outside-init
@@ -37,4 +41,4 @@ class SandboxEventEntity(SandboxProxyEntity, EventEntity):
event_attrs = dict(attributes)
self._EventEntity__last_event_type = event_attrs.pop(ATTR_EVENT_TYPE, None)
self._EventEntity__last_event_attributes = event_attrs or None
super().sandbox_apply_state(state, attributes)
super().sandbox_apply_state(state, attributes, context)
@@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any
from homeassistant.components.notify import NotifyEntity, NotifyEntityFeature
from homeassistant.core import Context
from . import SandboxProxyEntity
@@ -26,13 +27,16 @@ class SandboxNotifyEntity(SandboxProxyEntity, NotifyEntity):
)
def sandbox_apply_state(
self, state: str | None, attributes: dict[str, Any]
self,
state: str | None,
attributes: dict[str, Any],
context: Context | None = None,
) -> None:
"""Mirror ``__last_notified_isoformat`` for state computation."""
if state is not None:
# pylint: disable-next=attribute-defined-outside-init
self._NotifyEntity__last_notified_isoformat = state
super().sandbox_apply_state(state, attributes)
super().sandbox_apply_state(state, attributes, context)
async def async_send_message(self, message: str, title: str | None = None) -> None:
"""Forward send_message."""
@@ -8,6 +8,7 @@ covers the full set so a future classifier change doesn't surprise us.
from typing import Any
from homeassistant.components.scene import Scene
from homeassistant.core import Context
from . import SandboxProxyEntity
@@ -17,13 +18,16 @@ class SandboxSceneEntity(SandboxProxyEntity, Scene):
"""Proxy for a ``scene`` entity in a sandbox."""
def sandbox_apply_state(
self, state: str | None, attributes: dict[str, Any]
self,
state: str | None,
attributes: dict[str, Any],
context: Context | None = None,
) -> None:
"""Mirror the sandbox-side last-activated timestamp."""
if state is not None:
# pylint: disable-next=attribute-defined-outside-init
self._BaseScene__last_activated = state
super().sandbox_apply_state(state, attributes)
super().sandbox_apply_state(state, attributes, context)
async def async_activate(self, **kwargs: Any) -> None:
"""Forward activate as ``scene.turn_on``."""
@@ -20,10 +20,12 @@ from dataclasses import dataclass
import logging
import sys
import time
from typing import Any
from homeassistant.core import HomeAssistant
from .channel import Channel, ChannelClosedError, ChannelRemoteError
from .codec_protobuf import ProtobufCodec
from .protocol import MSG_READY, MSG_SHUTDOWN
_LOGGER = logging.getLogger(__name__)
@@ -36,7 +38,9 @@ DEFAULT_SHUTDOWN_GRACE = 10.0
CommandFactory = Callable[[str], list[str]]
TokenFactory = Callable[[str], Awaitable[str]]
ShutdownReplyCallback = Callable[[str, dict], Awaitable[None]]
# The reply is a protobuf ``ShutdownResult``; typed loosely to keep the
# manager free of a proto import.
ShutdownReplyCallback = Callable[[str, Any], Awaitable[None]]
class SandboxV2Error(Exception):
@@ -236,7 +240,7 @@ class SandboxProcess:
callback = self._on_shutdown_reply
if callback is not None:
try:
await callback(self.group, reply or {})
await callback(self.group, reply)
except Exception:
_LOGGER.exception(
"Sandbox %s on_shutdown_reply callback raised", self.group
@@ -378,7 +382,7 @@ class SandboxProcess:
"""
assert proc.stdout is not None
assert proc.stdin is not None
return Channel(proc.stdout, proc.stdin, name=self.group)
return Channel(proc.stdout, proc.stdin, name=self.group, codec=ProtobufCodec())
async def _drain_stream(
self, stream: asyncio.StreamReader | None, name: str
@@ -6,5 +6,6 @@
"documentation": "https://www.home-assistant.io/integrations/sandbox_v2",
"integration_type": "system",
"iot_class": "local_push",
"quality_scale": "internal"
"quality_scale": "internal",
"requirements": ["protobuf==6.32.0"]
}
@@ -0,0 +1,222 @@
"""Typed protobuf message registry + dynamic-field helpers.
This module is the codec's view of the wire: the ``type → (request_cls,
result_cls)`` registry plus the small Struct/ListValue helpers that carry the
genuinely dynamic payloads (service_data, target, state attributes,
capabilities, the wrapped Store envelope, flow ``data``/``errors``/``context``)
and the serialized voluptuous schema.
Mirrored verbatim across the no-cross-import boundary, exactly like
:mod:`channel` / :mod:`protocol`: the same file lives at
``hass_client.messages``. The relative ``._proto`` import resolves to each
side's own checked-in gencode, so the two copies are byte-identical.
Numbers note: ``google.protobuf.Struct`` stores every number as a double, so
an ``int`` that crosses inside a dynamic field comes back as a ``float``
(``255`` → ``255.0``). Python's ``==`` treats the two as equal, so dict
comparisons still hold; only an ``isinstance(x, int)`` check would notice.
Everything with integer semantics that matters (``version``, ``minor_version``,
``supported_features``) is an explicit ``int32`` field, not a Struct value.
"""
from typing import Any
from google.protobuf.message import Message
# pylint: disable-next=no-name-in-module
from google.protobuf.struct_pb2 import ListValue, Struct, Value
from ._proto import sandbox_v2_pb2 as pb
# Wire type → (request message class, result message class). The result class
# is ``None`` for one-way pushes (ready / state_changed / fire_event). The
# codec resolves these from ``frame.type`` on both encode and decode.
REGISTRY: dict[str, tuple[type[Message], type[Message] | None]] = {
# handshake (push)
"sandbox_v2/ready": (pb.Ready, None),
# main → sandbox
"sandbox_v2/entry_setup": (pb.EntrySetup, pb.EntrySetupResult),
"sandbox_v2/entry_unload": (pb.EntryUnload, pb.EntryUnloadResult),
"sandbox_v2/call_service": (pb.CallService, pb.CallServiceResult),
"sandbox_v2/shutdown": (pb.Shutdown, pb.ShutdownResult),
"sandbox_v2/ping": (pb.Ping, pb.PingResult),
"sandbox_v2/flow_init": (pb.FlowInit, pb.FlowResult),
"sandbox_v2/flow_step": (pb.FlowStep, pb.FlowResult),
"sandbox_v2/flow_abort": (pb.FlowAbort, pb.FlowAbortResult),
# sandbox → main
"sandbox_v2/register_entity": (pb.EntityDescription, pb.RegisterEntityResult),
"sandbox_v2/unregister_entity": (pb.UnregisterEntity, pb.UnregisterEntityResult),
"sandbox_v2/state_changed": (pb.StateChanged, None),
"sandbox_v2/register_service": (pb.RegisterService, pb.RegisterServiceResult),
"sandbox_v2/unregister_service": (
pb.UnregisterService,
pb.UnregisterServiceResult,
),
"sandbox_v2/fire_event": (pb.FireEvent, None),
"sandbox_v2/store_load": (pb.StoreLoad, pb.StoreLoadResult),
"sandbox_v2/store_save": (pb.StoreSave, pb.StoreSaveResult),
"sandbox_v2/store_remove": (pb.StoreRemove, pb.StoreRemoveResult),
}
# --- Struct / ListValue helpers -------------------------------------------
def _value_to_py(value: Value) -> Any:
"""Convert one ``google.protobuf.Value`` into a plain Python value."""
kind = value.WhichOneof("kind")
if kind == "null_value" or kind is None:
return None
if kind == "number_value":
return value.number_value
if kind == "string_value":
return value.string_value
if kind == "bool_value":
return value.bool_value
if kind == "struct_value":
return struct_to_dict(value.struct_value)
return [_value_to_py(item) for item in value.list_value.values]
def struct_to_dict(struct: Struct) -> dict[str, Any]:
"""Convert a ``Struct`` into a plain ``dict`` (empty Struct → ``{}``)."""
return {key: _value_to_py(val) for key, val in struct.fields.items()}
def dict_to_struct(data: dict[str, Any] | None) -> Struct:
"""Convert a ``dict`` (or ``None``) into a ``Struct``."""
struct = Struct()
if data:
struct.update(data)
return struct
def listvalue_to_list(list_value: ListValue) -> list[Any]:
"""Convert a ``ListValue`` into a plain ``list``."""
return [_value_to_py(item) for item in list_value.values]
def list_to_listvalue(items: list[Any] | None) -> ListValue:
"""Convert a ``list`` (or ``None``) into a ``ListValue``."""
list_value = ListValue()
if items:
list_value.extend(items)
return list_value
# --- DeviceInfo bridging --------------------------------------------------
# Scalar string fields of the DeviceInfo proto, copied through verbatim when
# present in the JSON-flattened device_info dict.
_DEVICE_INFO_SCALARS = (
"entry_type",
"name",
"manufacturer",
"model",
"model_id",
"sw_version",
"hw_version",
"serial_number",
"suggested_area",
"configuration_url",
"default_name",
"default_manufacturer",
"default_model",
"translation_key",
)
def device_info_to_proto(flat: dict[str, Any] | None) -> pb.DeviceInfo | None:
"""Build a ``DeviceInfo`` proto from the JSON-flattened device_info dict.
The sandbox-side serializer (``entity_bridge._serialise_device_info``)
already flattens sets/tuples/enums: ``identifiers`` / ``connections`` are
lists of two-element lists, ``via_device`` is a two-element list, and
``entry_type`` is the enum's string value. This maps that shape onto the
explicit proto fields.
"""
if not flat:
return None
info = pb.DeviceInfo()
for key, raw in flat.items():
if raw is None:
continue
if key in ("identifiers", "connections"):
for pair in raw:
if len(pair) == 2:
getattr(info, key).add(key=str(pair[0]), value=str(pair[1]))
elif key == "via_device":
if len(raw) == 2:
info.via_device.key = str(raw[0])
info.via_device.value = str(raw[1])
elif key in _DEVICE_INFO_SCALARS:
setattr(info, key, str(raw))
return info
def make_entity_description(
*,
entry_id: str,
domain: str,
sandbox_entity_id: str,
unique_id: str | None = None,
name: str | None = None,
icon: str | None = None,
has_entity_name: bool = False,
entity_category: str | None = None,
device_class: str | None = None,
supported_features: int = 0,
translation_key: str | None = None,
capabilities: dict[str, Any] | None = None,
initial_state: str | None = None,
initial_attributes: dict[str, Any] | None = None,
device_info: dict[str, Any] | None = None,
) -> pb.EntityDescription:
"""Build a nested ``EntityDescription`` proto from flat fields.
Used by the sandbox entity bridge and by tests so neither has to hand-nest
the ``EntityInfo`` / ``InitialState`` sub-messages. ``device_info`` is the
JSON-flattened dict the entity bridge produces (see
:func:`device_info_to_proto`).
"""
msg = pb.EntityDescription(
entry_id=entry_id,
domain=domain,
sandbox_entity_id=sandbox_entity_id,
has_entity_name=has_entity_name,
)
if unique_id is not None:
msg.unique_id = unique_id
description = msg.info.description
if name is not None:
description.name = name
if icon is not None:
description.icon = icon
if entity_category is not None:
description.entity_category = entity_category
if device_class is not None:
description.device_class = device_class
description.supported_features = int(supported_features or 0)
if translation_key is not None:
description.translation_key = translation_key
device = device_info_to_proto(device_info)
if device is not None:
msg.info.device_info.CopyFrom(device)
if initial_state is not None:
msg.initial.state = initial_state
if capabilities:
msg.initial.capabilities.update(capabilities)
if initial_attributes:
msg.initial.attributes.update(initial_attributes)
return msg
__all__ = [
"REGISTRY",
"device_info_to_proto",
"dict_to_struct",
"list_to_listvalue",
"listvalue_to_list",
"make_entity_description",
"struct_to_dict",
]
@@ -30,7 +30,9 @@ from typing import TYPE_CHECKING, Any
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
from homeassistant.data_entry_flow import FlowResultType
from ._proto import sandbox_v2_pb2 as pb
from .channel import ChannelClosedError, ChannelRemoteError
from .messages import dict_to_struct, listvalue_to_list, struct_to_dict
from .schema_bridge import reconstruct_schema
if TYPE_CHECKING:
@@ -105,18 +107,21 @@ class SandboxFlowProxy(ConfigFlow):
# framework's first call passes the initial data; for a
# USER source this is None. Everything else (REAUTH,
# DISCOVERY, …) gets its discovery payload here.
payload: dict[str, Any] = {
"handler": self._handler_key,
"context": dict(self.context),
"data": user_input,
}
result = await channel.call("sandbox_v2/flow_init", payload)
self._sandbox_flow_id = result.get("flow_id")
else:
result = await channel.call(
"sandbox_v2/flow_step",
{"flow_id": self._sandbox_flow_id, "user_input": user_input},
request = pb.FlowInit(
handler=self._handler_key,
context=dict_to_struct(dict(self.context)),
)
if user_input is not None:
request.data.CopyFrom(dict_to_struct(user_input))
result = await channel.call("sandbox_v2/flow_init", request)
self._sandbox_flow_id = (
result.flow_id if result.HasField("flow_id") else None
)
else:
step = pb.FlowStep(flow_id=self._sandbox_flow_id)
if user_input is not None:
step.user_input.CopyFrom(dict_to_struct(user_input))
result = await channel.call("sandbox_v2/flow_step", step)
except ChannelClosedError:
self._terminated = True
_LOGGER.warning(
@@ -139,7 +144,7 @@ class SandboxFlowProxy(ConfigFlow):
await self._apply_remote_context(result)
return self._adapt_result(result, step_id)
async def _apply_remote_context(self, result: dict[str, Any]) -> None:
async def _apply_remote_context(self, result: pb.FlowResult) -> None:
"""Mirror ``unique_id`` (and other context bits) onto our own flow.
The sandbox's :meth:`ConfigFlow.async_set_unique_id` mutates the
@@ -149,9 +154,9 @@ class SandboxFlowProxy(ConfigFlow):
(it raises :class:`AbortFlow` for an in-progress collision,
which the flow framework turns into an ABORT result).
"""
remote = result.get("context")
if not isinstance(remote, dict):
if not result.HasField("context"):
return
remote = struct_to_dict(result.context)
if "unique_id" not in remote:
return
unique_id = remote["unique_id"]
@@ -162,24 +167,35 @@ class SandboxFlowProxy(ConfigFlow):
# id; that's exactly the duplicate-rejection signal we want.
await self.async_set_unique_id(unique_id)
def _adapt_result(self, result: dict[str, Any], step_id: str) -> ConfigFlowResult:
"""Translate a sandbox-side FlowResult dict into a main-side one.
def _adapt_result(self, result: pb.FlowResult, step_id: str) -> ConfigFlowResult:
"""Translate a sandbox-side ``FlowResult`` message into a main-side one.
The sandbox's ``flow_id`` and ``handler`` are replaced with main's
view (so HA's frontend / FlowManager keep tracking the proxy
flow), and CREATE_ENTRY data is tagged with the sandbox group so
the setup interceptor knows where to route the entry.
"""
result_type = FlowResultType(result["type"])
result_type = FlowResultType(result.type)
placeholders = (
struct_to_dict(result.description_placeholders)
if result.HasField("description_placeholders")
else None
)
if result_type is FlowResultType.CREATE_ENTRY:
entry_data = dict(result.get("data") or {})
entry_data = struct_to_dict(result.data)
self._terminated = True
create_result = self.async_create_entry(
title=result.get("title") or self._handler_key,
title=(
result.title
if result.HasField("title") and result.title
else self._handler_key
),
data=entry_data,
description=result.get("description"),
description_placeholders=result.get("description_placeholders"),
description=(
result.description if result.HasField("description") else None
),
description_placeholders=placeholders,
)
# Tag the FlowResult so the framework's entry constructor in
# ``ConfigEntriesFlowManager.async_finish_flow`` reads it into
@@ -191,25 +207,30 @@ class SandboxFlowProxy(ConfigFlow):
if result_type is FlowResultType.ABORT:
self._terminated = True
return self.async_abort(
reason=result.get("reason", "sandbox_aborted"),
description_placeholders=result.get("description_placeholders"),
reason=(
result.reason if result.HasField("reason") else "sandbox_aborted"
),
description_placeholders=placeholders,
)
if result_type is FlowResultType.FORM:
data_schema = reconstruct_schema(result.get("data_schema"))
if data_schema is None and result.get("_has_data_schema"):
data_schema = reconstruct_schema(listvalue_to_list(result.data_schema))
if data_schema is None and result.has_data_schema:
_LOGGER.debug(
"Sandbox %r returned a FORM with an unserialisable"
" data_schema; rendering schema-less",
self._sandbox_group,
)
errors = (
struct_to_dict(result.errors) if result.HasField("errors") else None
)
return self.async_show_form(
step_id=result.get("step_id", step_id),
step_id=result.step_id if result.HasField("step_id") else step_id,
data_schema=data_schema,
errors=result.get("errors") or None,
description_placeholders=result.get("description_placeholders"),
last_step=result.get("last_step"),
preview=result.get("preview"),
errors=errors or None,
description_placeholders=placeholders,
last_step=result.last_step if result.HasField("last_step") else None,
preview=result.preview if result.HasField("preview") else None,
)
# Any other type (MENU, EXTERNAL_STEP, SHOW_PROGRESS, …) is
@@ -255,7 +276,7 @@ class SandboxFlowProxy(ConfigFlow):
async def _safe_abort(channel: Any, flow_id: str, group: str, handler: str) -> None:
"""Fire ``flow_abort`` on the sandbox and swallow errors."""
try:
await channel.call("sandbox_v2/flow_abort", {"flow_id": flow_id})
await channel.call("sandbox_v2/flow_abort", pb.FlowAbort(flow_id=flow_id))
except (ChannelClosedError, ChannelRemoteError) as err:
_LOGGER.debug("Sandbox %r flow_abort for %s failed: %s", group, handler, err)
+23 -17
View File
@@ -28,9 +28,11 @@ from homeassistant.config_entries import (
from homeassistant.core import HomeAssistant
from homeassistant.loader import async_get_integration
from ._proto import sandbox_v2_pb2 as pb
from .channel import ChannelClosedError, ChannelRemoteError
from .classifier import SandboxAssignment, classify
from .manager import SandboxManager
from .messages import dict_to_struct
from .protocol import MSG_ENTRY_SETUP, MSG_ENTRY_UNLOAD
from .proxy_flow import SandboxFlowProxy
@@ -129,8 +131,10 @@ class SandboxFlowRouter:
)
return False
if not result.get("ok"):
reason = result.get("reason") or "sandbox refused setup"
if not result.ok:
reason = (
result.reason if result.HasField("reason") else "sandbox refused setup"
)
entry._async_set_state( # noqa: SLF001
self._hass, ConfigEntryState.SETUP_ERROR, reason
)
@@ -153,7 +157,7 @@ class SandboxFlowRouter:
return True
try:
result = await sandbox.channel.call(
MSG_ENTRY_UNLOAD, {"entry_id": entry.entry_id}
MSG_ENTRY_UNLOAD, pb.EntryUnload(entry_id=entry.entry_id)
)
except ChannelClosedError, ChannelRemoteError:
_LOGGER.exception(
@@ -167,7 +171,7 @@ class SandboxFlowRouter:
bridge = self._data.bridges.get(group)
if bridge is not None:
await bridge.async_unload_entry(entry)
return bool(result.get("ok", True))
return result.ok
async def _assignment_for_new_flow(self, handler_key: str) -> SandboxAssignment:
"""Decide where a new flow for ``handler_key`` should run.
@@ -183,23 +187,25 @@ class SandboxFlowRouter:
return classify(integration)
def _entry_setup_payload(entry: ConfigEntry) -> dict[str, Any]:
"""Build the wire payload for ``sandbox_v2/entry_setup``.
def _entry_setup_payload(entry: ConfigEntry) -> pb.EntrySetup:
"""Build the typed ``EntrySetup`` message for ``sandbox_v2/entry_setup``.
Surfaces the small subset of entry fields the integration's
``async_setup_entry`` reads.
"""
return {
"entry_id": entry.entry_id,
"domain": entry.domain,
"title": entry.title,
"data": dict(entry.data),
"options": dict(entry.options),
"source": entry.source,
"unique_id": entry.unique_id,
"version": entry.version,
"minor_version": entry.minor_version,
}
msg = pb.EntrySetup(
entry_id=entry.entry_id,
domain=entry.domain,
title=entry.title,
data=dict_to_struct(dict(entry.data)),
options=dict_to_struct(dict(entry.options)),
source=entry.source,
version=entry.version,
minor_version=entry.minor_version,
)
if entry.unique_id is not None:
msg.unique_id = entry.unique_id
return msg
__all__ = ["SandboxFlowRouter"]
+5
View File
@@ -102,6 +102,8 @@ include = ["homeassistant*"]
[tool.pylint.MAIN]
py-version = "3.14"
# Checked-in protobuf gencode (sandbox_v2) is machine-generated — never lint it.
ignore-paths = [".*_pb2\\.pyi?$"]
# Use a conservative default here; 2 should speed up most setups and not hurt
# any too bad. Override on command line as appropriate.
jobs = 2
@@ -649,6 +651,9 @@ exclude_lines = [
[tool.ruff]
required-version = ">=0.15.13"
# Checked-in protobuf gencode (sandbox_v2) — machine-generated, regenerated by
# sandbox_v2/proto/generate.sh; never hand-edited, so never linted.
extend-exclude = ["*_pb2.py", "*_pb2.pyi"]
[tool.ruff.lint]
select = [
+3
View File
@@ -1860,6 +1860,9 @@ proliphix==0.4.1
# homeassistant.components.prometheus
prometheus-client==0.21.0
# homeassistant.components.sandbox_v2
protobuf==6.32.0
# homeassistant.components.prowl
prowlpy==1.1.5
File diff suppressed because one or more lines are too long
@@ -0,0 +1,427 @@
from google.protobuf import struct_pb2 as _struct_pb2
from google.protobuf.internal import containers as _containers
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class Frame(_message.Message):
__slots__ = ("id", "type", "request", "response")
ID_FIELD_NUMBER: _ClassVar[int]
TYPE_FIELD_NUMBER: _ClassVar[int]
REQUEST_FIELD_NUMBER: _ClassVar[int]
RESPONSE_FIELD_NUMBER: _ClassVar[int]
id: int
type: str
request: bytes
response: Response
def __init__(self, id: _Optional[int] = ..., type: _Optional[str] = ..., request: _Optional[bytes] = ..., response: _Optional[_Union[Response, _Mapping]] = ...) -> None: ...
class Response(_message.Message):
__slots__ = ("ok", "result", "error")
OK_FIELD_NUMBER: _ClassVar[int]
RESULT_FIELD_NUMBER: _ClassVar[int]
ERROR_FIELD_NUMBER: _ClassVar[int]
ok: bool
result: bytes
error: Error
def __init__(self, ok: bool = ..., result: _Optional[bytes] = ..., error: _Optional[_Union[Error, _Mapping]] = ...) -> None: ...
class Error(_message.Message):
__slots__ = ("message", "type", "invalid", "multiple")
MESSAGE_FIELD_NUMBER: _ClassVar[int]
TYPE_FIELD_NUMBER: _ClassVar[int]
INVALID_FIELD_NUMBER: _ClassVar[int]
MULTIPLE_FIELD_NUMBER: _ClassVar[int]
message: str
type: str
invalid: _containers.RepeatedCompositeFieldContainer[InvalidError]
multiple: bool
def __init__(self, message: _Optional[str] = ..., type: _Optional[str] = ..., invalid: _Optional[_Iterable[_Union[InvalidError, _Mapping]]] = ..., multiple: bool = ...) -> None: ...
class InvalidError(_message.Message):
__slots__ = ("message", "path")
MESSAGE_FIELD_NUMBER: _ClassVar[int]
PATH_FIELD_NUMBER: _ClassVar[int]
message: str
path: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, message: _Optional[str] = ..., path: _Optional[_Iterable[str]] = ...) -> None: ...
class DevicePair(_message.Message):
__slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int]
VALUE_FIELD_NUMBER: _ClassVar[int]
key: str
value: str
def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ...
class DeviceInfo(_message.Message):
__slots__ = ("identifiers", "connections", "via_device", "entry_type", "name", "manufacturer", "model", "model_id", "sw_version", "hw_version", "serial_number", "suggested_area", "configuration_url", "default_name", "default_manufacturer", "default_model", "translation_key")
IDENTIFIERS_FIELD_NUMBER: _ClassVar[int]
CONNECTIONS_FIELD_NUMBER: _ClassVar[int]
VIA_DEVICE_FIELD_NUMBER: _ClassVar[int]
ENTRY_TYPE_FIELD_NUMBER: _ClassVar[int]
NAME_FIELD_NUMBER: _ClassVar[int]
MANUFACTURER_FIELD_NUMBER: _ClassVar[int]
MODEL_FIELD_NUMBER: _ClassVar[int]
MODEL_ID_FIELD_NUMBER: _ClassVar[int]
SW_VERSION_FIELD_NUMBER: _ClassVar[int]
HW_VERSION_FIELD_NUMBER: _ClassVar[int]
SERIAL_NUMBER_FIELD_NUMBER: _ClassVar[int]
SUGGESTED_AREA_FIELD_NUMBER: _ClassVar[int]
CONFIGURATION_URL_FIELD_NUMBER: _ClassVar[int]
DEFAULT_NAME_FIELD_NUMBER: _ClassVar[int]
DEFAULT_MANUFACTURER_FIELD_NUMBER: _ClassVar[int]
DEFAULT_MODEL_FIELD_NUMBER: _ClassVar[int]
TRANSLATION_KEY_FIELD_NUMBER: _ClassVar[int]
identifiers: _containers.RepeatedCompositeFieldContainer[DevicePair]
connections: _containers.RepeatedCompositeFieldContainer[DevicePair]
via_device: DevicePair
entry_type: str
name: str
manufacturer: str
model: str
model_id: str
sw_version: str
hw_version: str
serial_number: str
suggested_area: str
configuration_url: str
default_name: str
default_manufacturer: str
default_model: str
translation_key: str
def __init__(self, identifiers: _Optional[_Iterable[_Union[DevicePair, _Mapping]]] = ..., connections: _Optional[_Iterable[_Union[DevicePair, _Mapping]]] = ..., via_device: _Optional[_Union[DevicePair, _Mapping]] = ..., entry_type: _Optional[str] = ..., name: _Optional[str] = ..., manufacturer: _Optional[str] = ..., model: _Optional[str] = ..., model_id: _Optional[str] = ..., sw_version: _Optional[str] = ..., hw_version: _Optional[str] = ..., serial_number: _Optional[str] = ..., suggested_area: _Optional[str] = ..., configuration_url: _Optional[str] = ..., default_name: _Optional[str] = ..., default_manufacturer: _Optional[str] = ..., default_model: _Optional[str] = ..., translation_key: _Optional[str] = ...) -> None: ...
class EntrySetup(_message.Message):
__slots__ = ("entry_id", "domain", "title", "data", "options", "source", "unique_id", "version", "minor_version")
ENTRY_ID_FIELD_NUMBER: _ClassVar[int]
DOMAIN_FIELD_NUMBER: _ClassVar[int]
TITLE_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
OPTIONS_FIELD_NUMBER: _ClassVar[int]
SOURCE_FIELD_NUMBER: _ClassVar[int]
UNIQUE_ID_FIELD_NUMBER: _ClassVar[int]
VERSION_FIELD_NUMBER: _ClassVar[int]
MINOR_VERSION_FIELD_NUMBER: _ClassVar[int]
entry_id: str
domain: str
title: str
data: _struct_pb2.Struct
options: _struct_pb2.Struct
source: str
unique_id: str
version: int
minor_version: int
def __init__(self, entry_id: _Optional[str] = ..., domain: _Optional[str] = ..., title: _Optional[str] = ..., data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., options: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., source: _Optional[str] = ..., unique_id: _Optional[str] = ..., version: _Optional[int] = ..., minor_version: _Optional[int] = ...) -> None: ...
class EntrySetupResult(_message.Message):
__slots__ = ("ok", "reason")
OK_FIELD_NUMBER: _ClassVar[int]
REASON_FIELD_NUMBER: _ClassVar[int]
ok: bool
reason: str
def __init__(self, ok: bool = ..., reason: _Optional[str] = ...) -> None: ...
class EntryUnload(_message.Message):
__slots__ = ("entry_id",)
ENTRY_ID_FIELD_NUMBER: _ClassVar[int]
entry_id: str
def __init__(self, entry_id: _Optional[str] = ...) -> None: ...
class EntryUnloadResult(_message.Message):
__slots__ = ("ok",)
OK_FIELD_NUMBER: _ClassVar[int]
ok: bool
def __init__(self, ok: bool = ...) -> None: ...
class CallService(_message.Message):
__slots__ = ("domain", "service", "target", "service_data", "context_id", "return_response")
DOMAIN_FIELD_NUMBER: _ClassVar[int]
SERVICE_FIELD_NUMBER: _ClassVar[int]
TARGET_FIELD_NUMBER: _ClassVar[int]
SERVICE_DATA_FIELD_NUMBER: _ClassVar[int]
CONTEXT_ID_FIELD_NUMBER: _ClassVar[int]
RETURN_RESPONSE_FIELD_NUMBER: _ClassVar[int]
domain: str
service: str
target: _struct_pb2.Struct
service_data: _struct_pb2.Struct
context_id: str
return_response: bool
def __init__(self, domain: _Optional[str] = ..., service: _Optional[str] = ..., target: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., service_data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., context_id: _Optional[str] = ..., return_response: bool = ...) -> None: ...
class ServiceResponse(_message.Message):
__slots__ = ("data",)
DATA_FIELD_NUMBER: _ClassVar[int]
data: _struct_pb2.Struct
def __init__(self, data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class CallServiceResult(_message.Message):
__slots__ = ("response",)
RESPONSE_FIELD_NUMBER: _ClassVar[int]
response: ServiceResponse
def __init__(self, response: _Optional[_Union[ServiceResponse, _Mapping]] = ...) -> None: ...
class Shutdown(_message.Message):
__slots__ = ()
def __init__(self) -> None: ...
class ShutdownResult(_message.Message):
__slots__ = ("ok", "unloaded", "restore_state")
OK_FIELD_NUMBER: _ClassVar[int]
UNLOADED_FIELD_NUMBER: _ClassVar[int]
RESTORE_STATE_FIELD_NUMBER: _ClassVar[int]
ok: bool
unloaded: int
restore_state: _struct_pb2.Struct
def __init__(self, ok: bool = ..., unloaded: _Optional[int] = ..., restore_state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class Ping(_message.Message):
__slots__ = ()
def __init__(self) -> None: ...
class PingResult(_message.Message):
__slots__ = ("pong",)
PONG_FIELD_NUMBER: _ClassVar[int]
pong: str
def __init__(self, pong: _Optional[str] = ...) -> None: ...
class Ready(_message.Message):
__slots__ = ()
def __init__(self) -> None: ...
class FlowInit(_message.Message):
__slots__ = ("handler", "context", "data")
HANDLER_FIELD_NUMBER: _ClassVar[int]
CONTEXT_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
handler: str
context: _struct_pb2.Struct
data: _struct_pb2.Struct
def __init__(self, handler: _Optional[str] = ..., context: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class FlowStep(_message.Message):
__slots__ = ("flow_id", "user_input")
FLOW_ID_FIELD_NUMBER: _ClassVar[int]
USER_INPUT_FIELD_NUMBER: _ClassVar[int]
flow_id: str
user_input: _struct_pb2.Struct
def __init__(self, flow_id: _Optional[str] = ..., user_input: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class FlowAbort(_message.Message):
__slots__ = ("flow_id",)
FLOW_ID_FIELD_NUMBER: _ClassVar[int]
flow_id: str
def __init__(self, flow_id: _Optional[str] = ...) -> None: ...
class FlowAbortResult(_message.Message):
__slots__ = ()
def __init__(self) -> None: ...
class FlowResult(_message.Message):
__slots__ = ("type", "flow_id", "handler", "step_id", "reason", "title", "description", "last_step", "preview", "version", "minor_version", "data", "options", "errors", "description_placeholders", "context", "data_schema", "has_data_schema")
TYPE_FIELD_NUMBER: _ClassVar[int]
FLOW_ID_FIELD_NUMBER: _ClassVar[int]
HANDLER_FIELD_NUMBER: _ClassVar[int]
STEP_ID_FIELD_NUMBER: _ClassVar[int]
REASON_FIELD_NUMBER: _ClassVar[int]
TITLE_FIELD_NUMBER: _ClassVar[int]
DESCRIPTION_FIELD_NUMBER: _ClassVar[int]
LAST_STEP_FIELD_NUMBER: _ClassVar[int]
PREVIEW_FIELD_NUMBER: _ClassVar[int]
VERSION_FIELD_NUMBER: _ClassVar[int]
MINOR_VERSION_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
OPTIONS_FIELD_NUMBER: _ClassVar[int]
ERRORS_FIELD_NUMBER: _ClassVar[int]
DESCRIPTION_PLACEHOLDERS_FIELD_NUMBER: _ClassVar[int]
CONTEXT_FIELD_NUMBER: _ClassVar[int]
DATA_SCHEMA_FIELD_NUMBER: _ClassVar[int]
HAS_DATA_SCHEMA_FIELD_NUMBER: _ClassVar[int]
type: str
flow_id: str
handler: str
step_id: str
reason: str
title: str
description: str
last_step: bool
preview: str
version: int
minor_version: int
data: _struct_pb2.Struct
options: _struct_pb2.Struct
errors: _struct_pb2.Struct
description_placeholders: _struct_pb2.Struct
context: _struct_pb2.Struct
data_schema: _struct_pb2.ListValue
has_data_schema: bool
def __init__(self, type: _Optional[str] = ..., flow_id: _Optional[str] = ..., handler: _Optional[str] = ..., step_id: _Optional[str] = ..., reason: _Optional[str] = ..., title: _Optional[str] = ..., description: _Optional[str] = ..., last_step: bool = ..., preview: _Optional[str] = ..., version: _Optional[int] = ..., minor_version: _Optional[int] = ..., data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., options: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., errors: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., description_placeholders: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., context: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., data_schema: _Optional[_Union[_struct_pb2.ListValue, _Mapping]] = ..., has_data_schema: bool = ...) -> None: ...
class EntityInfo(_message.Message):
__slots__ = ("description", "device_info")
class Description(_message.Message):
__slots__ = ("name", "icon", "entity_category", "device_class", "supported_features", "translation_key")
NAME_FIELD_NUMBER: _ClassVar[int]
ICON_FIELD_NUMBER: _ClassVar[int]
ENTITY_CATEGORY_FIELD_NUMBER: _ClassVar[int]
DEVICE_CLASS_FIELD_NUMBER: _ClassVar[int]
SUPPORTED_FEATURES_FIELD_NUMBER: _ClassVar[int]
TRANSLATION_KEY_FIELD_NUMBER: _ClassVar[int]
name: str
icon: str
entity_category: str
device_class: str
supported_features: int
translation_key: str
def __init__(self, name: _Optional[str] = ..., icon: _Optional[str] = ..., entity_category: _Optional[str] = ..., device_class: _Optional[str] = ..., supported_features: _Optional[int] = ..., translation_key: _Optional[str] = ...) -> None: ...
DESCRIPTION_FIELD_NUMBER: _ClassVar[int]
DEVICE_INFO_FIELD_NUMBER: _ClassVar[int]
description: EntityInfo.Description
device_info: DeviceInfo
def __init__(self, description: _Optional[_Union[EntityInfo.Description, _Mapping]] = ..., device_info: _Optional[_Union[DeviceInfo, _Mapping]] = ...) -> None: ...
class InitialState(_message.Message):
__slots__ = ("state", "capabilities", "attributes")
STATE_FIELD_NUMBER: _ClassVar[int]
CAPABILITIES_FIELD_NUMBER: _ClassVar[int]
ATTRIBUTES_FIELD_NUMBER: _ClassVar[int]
state: str
capabilities: _struct_pb2.Struct
attributes: _struct_pb2.Struct
def __init__(self, state: _Optional[str] = ..., capabilities: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., attributes: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class EntityDescription(_message.Message):
__slots__ = ("entry_id", "domain", "sandbox_entity_id", "unique_id", "has_entity_name", "info", "initial")
ENTRY_ID_FIELD_NUMBER: _ClassVar[int]
DOMAIN_FIELD_NUMBER: _ClassVar[int]
SANDBOX_ENTITY_ID_FIELD_NUMBER: _ClassVar[int]
UNIQUE_ID_FIELD_NUMBER: _ClassVar[int]
HAS_ENTITY_NAME_FIELD_NUMBER: _ClassVar[int]
INFO_FIELD_NUMBER: _ClassVar[int]
INITIAL_FIELD_NUMBER: _ClassVar[int]
entry_id: str
domain: str
sandbox_entity_id: str
unique_id: str
has_entity_name: bool
info: EntityInfo
initial: InitialState
def __init__(self, entry_id: _Optional[str] = ..., domain: _Optional[str] = ..., sandbox_entity_id: _Optional[str] = ..., unique_id: _Optional[str] = ..., has_entity_name: bool = ..., info: _Optional[_Union[EntityInfo, _Mapping]] = ..., initial: _Optional[_Union[InitialState, _Mapping]] = ...) -> None: ...
class RegisterEntityResult(_message.Message):
__slots__ = ("entity_id",)
ENTITY_ID_FIELD_NUMBER: _ClassVar[int]
entity_id: str
def __init__(self, entity_id: _Optional[str] = ...) -> None: ...
class UnregisterEntity(_message.Message):
__slots__ = ("sandbox_entity_id",)
SANDBOX_ENTITY_ID_FIELD_NUMBER: _ClassVar[int]
sandbox_entity_id: str
def __init__(self, sandbox_entity_id: _Optional[str] = ...) -> None: ...
class UnregisterEntityResult(_message.Message):
__slots__ = ("ok",)
OK_FIELD_NUMBER: _ClassVar[int]
ok: bool
def __init__(self, ok: bool = ...) -> None: ...
class StateChanged(_message.Message):
__slots__ = ("sandbox_entity_id", "state", "attributes", "context_id")
SANDBOX_ENTITY_ID_FIELD_NUMBER: _ClassVar[int]
STATE_FIELD_NUMBER: _ClassVar[int]
ATTRIBUTES_FIELD_NUMBER: _ClassVar[int]
CONTEXT_ID_FIELD_NUMBER: _ClassVar[int]
sandbox_entity_id: str
state: str
attributes: _struct_pb2.Struct
context_id: str
def __init__(self, sandbox_entity_id: _Optional[str] = ..., state: _Optional[str] = ..., attributes: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., context_id: _Optional[str] = ...) -> None: ...
class RegisterService(_message.Message):
__slots__ = ("domain", "service", "supports_response", "schema")
DOMAIN_FIELD_NUMBER: _ClassVar[int]
SERVICE_FIELD_NUMBER: _ClassVar[int]
SUPPORTS_RESPONSE_FIELD_NUMBER: _ClassVar[int]
SCHEMA_FIELD_NUMBER: _ClassVar[int]
domain: str
service: str
supports_response: str
schema: _struct_pb2.ListValue
def __init__(self, domain: _Optional[str] = ..., service: _Optional[str] = ..., supports_response: _Optional[str] = ..., schema: _Optional[_Union[_struct_pb2.ListValue, _Mapping]] = ...) -> None: ...
class RegisterServiceResult(_message.Message):
__slots__ = ("ok", "installed")
OK_FIELD_NUMBER: _ClassVar[int]
INSTALLED_FIELD_NUMBER: _ClassVar[int]
ok: bool
installed: bool
def __init__(self, ok: bool = ..., installed: bool = ...) -> None: ...
class UnregisterService(_message.Message):
__slots__ = ("domain", "service")
DOMAIN_FIELD_NUMBER: _ClassVar[int]
SERVICE_FIELD_NUMBER: _ClassVar[int]
domain: str
service: str
def __init__(self, domain: _Optional[str] = ..., service: _Optional[str] = ...) -> None: ...
class UnregisterServiceResult(_message.Message):
__slots__ = ("ok", "removed")
OK_FIELD_NUMBER: _ClassVar[int]
REMOVED_FIELD_NUMBER: _ClassVar[int]
ok: bool
removed: bool
def __init__(self, ok: bool = ..., removed: bool = ...) -> None: ...
class FireEvent(_message.Message):
__slots__ = ("event_type", "event_data", "context_id")
EVENT_TYPE_FIELD_NUMBER: _ClassVar[int]
EVENT_DATA_FIELD_NUMBER: _ClassVar[int]
CONTEXT_ID_FIELD_NUMBER: _ClassVar[int]
event_type: str
event_data: _struct_pb2.Struct
context_id: str
def __init__(self, event_type: _Optional[str] = ..., event_data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., context_id: _Optional[str] = ...) -> None: ...
class StoreLoad(_message.Message):
__slots__ = ("key",)
KEY_FIELD_NUMBER: _ClassVar[int]
key: str
def __init__(self, key: _Optional[str] = ...) -> None: ...
class StoreLoadResult(_message.Message):
__slots__ = ("data",)
DATA_FIELD_NUMBER: _ClassVar[int]
data: _struct_pb2.Struct
def __init__(self, data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class StoreSave(_message.Message):
__slots__ = ("key", "data")
KEY_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
key: str
data: _struct_pb2.Struct
def __init__(self, key: _Optional[str] = ..., data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class StoreSaveResult(_message.Message):
__slots__ = ("ok",)
OK_FIELD_NUMBER: _ClassVar[int]
ok: bool
def __init__(self, ok: bool = ...) -> None: ...
class StoreRemove(_message.Message):
__slots__ = ("key",)
KEY_FIELD_NUMBER: _ClassVar[int]
key: str
def __init__(self, key: _Optional[str] = ...) -> None: ...
class StoreRemoveResult(_message.Message):
__slots__ = ("ok",)
OK_FIELD_NUMBER: _ClassVar[int]
ok: bool
def __init__(self, ok: bool = ...) -> None: ...
+17 -13
View File
@@ -103,9 +103,15 @@ class Frame:
return cls(FrameKind.PUSH, id=0, type=msg_type, payload=payload)
@classmethod
def ok_response(cls, call_id: int, result: Any) -> Frame:
"""Build a success response frame."""
return cls(FrameKind.RESPONSE, id=call_id, ok=True, result=result)
def ok_response(cls, call_id: int, result: Any, msg_type: str = "") -> Frame:
"""Build a success response frame.
``msg_type`` is carried so a stateless codec (the protobuf one) can
look up the result message class on encode + decode.
"""
return cls(
FrameKind.RESPONSE, id=call_id, type=msg_type, ok=True, result=result
)
@classmethod
def error_response(
@@ -114,11 +120,13 @@ class Frame:
error: str,
error_type: str | None,
error_data: dict[str, Any] | None = None,
msg_type: str = "",
) -> Frame:
"""Build a failure response frame."""
return cls(
FrameKind.RESPONSE,
id=call_id,
type=msg_type,
ok=False,
error=error,
error_type=error_type,
@@ -412,9 +420,7 @@ class Channel:
try:
data = await self._transport.read_frame()
except FrameTooLargeError as err:
_LOGGER.error(
"channel %s: %s; aborting channel", self._name, err
)
_LOGGER.error("channel %s: %s; aborting channel", self._name, err)
return
if data is None:
return
@@ -438,9 +444,7 @@ class Channel:
for future in self._pending.values():
if not future.done():
future.set_exception(
ChannelClosedError(
f"channel {self._name!r} stream ended"
)
ChannelClosedError(f"channel {self._name!r} stream ended")
)
self._pending.clear()
for task in list(self._inflight):
@@ -480,6 +484,7 @@ class Channel:
frame.id,
f"no handler for {frame.type!r}",
"ChannelUnknownType",
msg_type=frame.type,
)
)
)
@@ -491,9 +496,7 @@ class Channel:
def _spawn_handler(self, coro: Coroutine[Any, Any, Any]) -> None:
"""Start a handler task and track it for cancellation on close."""
task = asyncio.create_task(
coro, name=f"sandbox_v2[{self._name}]:dispatch"
)
task = asyncio.create_task(coro, name=f"sandbox_v2[{self._name}]:dispatch")
self._inflight.add(task)
task.add_done_callback(self._inflight.discard)
@@ -534,6 +537,7 @@ class Channel:
str(err) or err.__class__.__name__,
err.__class__.__name__,
error_data_for(err),
msg_type=msg_type,
)
with contextlib.suppress(Exception):
await self._write(frame)
@@ -541,7 +545,7 @@ class Channel:
if self._closed:
return
with contextlib.suppress(Exception):
await self._write(Frame.ok_response(call_id, result))
await self._write(Frame.ok_response(call_id, result, msg_type))
__all__ = [
@@ -0,0 +1,134 @@
"""Protobuf :class:`~.channel.Codec` — the production wire.
Serialises a :class:`~.channel.Frame` to the protobuf ``Frame`` envelope and
back. The envelope carries ``type`` on responses too, so this stateless codec
can look up the result message class from ``frame.type`` on both encode and
decode the dispatch core never has to know about proto types (the registry
lives here, not on :meth:`Channel.register`).
Mirrored verbatim across the no-cross-import boundary (the same file lives at
``hass_client.codec_protobuf``); the relative imports resolve to each side's
own :mod:`messages` + ``_proto`` gencode.
"""
from typing import Any
from google.protobuf.message import Message
from ._proto import sandbox_v2_pb2 as pb
from .channel import Frame, FrameKind
from .messages import REGISTRY
Registry = dict[str, tuple[type[Message], type[Message] | None]]
class ProtobufCodec:
"""Encode/decode :class:`Frame` objects as protobuf ``Frame`` envelopes."""
def __init__(self, registry: Registry | None = None) -> None:
"""Build the codec over a ``type → (request_cls, result_cls)`` map."""
self._registry = registry if registry is not None else REGISTRY
def _classes(
self, msg_type: str
) -> tuple[type[Message] | None, type[Message] | None]:
return self._registry.get(msg_type, (None, None))
def encode(self, frame: Frame) -> bytes:
"""Serialise a frame to the protobuf ``Frame`` envelope bytes."""
envelope = pb.Frame(id=frame.id, type=frame.type)
if frame.kind is FrameKind.RESPONSE:
response = envelope.response
response.ok = frame.ok
if frame.ok:
_, result_cls = self._classes(frame.type)
response.result = _serialize_body(frame.result, result_cls)
else:
_fill_error(response.error, frame)
else:
request_cls, _ = self._classes(frame.type)
envelope.request = _serialize_body(frame.payload, request_cls)
return envelope.SerializeToString()
def decode(self, data: bytes) -> Frame:
"""Rebuild a frame from protobuf ``Frame`` envelope bytes."""
envelope = pb.Frame.FromString(data)
msg_type = envelope.type
body = envelope.WhichOneof("body")
if body == "response":
response = envelope.response
if response.ok:
_, result_cls = self._classes(msg_type)
result = _parse_body(response.result, result_cls)
return Frame.ok_response(envelope.id, result, msg_type)
error, error_type, error_data = _read_error(response.error)
return Frame.error_response(
envelope.id, error, error_type, error_data, msg_type
)
request_cls, _ = self._classes(msg_type)
payload = _parse_body(envelope.request, request_cls)
if envelope.id == 0:
return Frame.push(msg_type, payload)
return Frame.call(envelope.id, msg_type, payload)
def _serialize_body(body: Any, cls: type[Message] | None) -> bytes:
"""Serialise a proto-message body; ``None`` becomes an empty message."""
if body is None:
return cls().SerializeToString() if cls is not None else b""
if isinstance(body, Message):
return body.SerializeToString()
raise TypeError(
f"ProtobufCodec expected a proto message body, got {type(body).__name__}"
)
def _parse_body(raw: bytes, cls: type[Message] | None) -> Any:
"""Deserialise a body into ``cls``; an unregistered type decodes to None."""
if cls is None:
return None
return cls.FromString(raw)
def _fill_error(error: pb.Error, frame: Frame) -> None:
"""Populate the proto ``Error`` from a failure frame.
Carries fidelity #7's structured voluptuous data: the ``multiple`` flag
distinguishes a ``MultipleInvalid`` from a single ``Invalid`` so the peer
rebuilds the right exception.
"""
error.message = frame.error or ""
error.type = frame.error_type or ""
data = frame.error_data
if not data:
return
if data.get("kind") == "multiple":
error.multiple = True
for child in data.get("errors", []):
error.invalid.add(message=child.get("msg", ""), path=child.get("path", []))
elif data.get("kind") == "invalid":
error.invalid.add(message=data.get("msg", ""), path=data.get("path", []))
def _read_error(error: pb.Error) -> tuple[str, str | None, dict[str, Any] | None]:
"""Rebuild ``(message, type, error_data)`` from the proto ``Error``."""
error_data: dict[str, Any] | None = None
if error.multiple:
error_data = {
"kind": "multiple",
"errors": [
{"kind": "invalid", "msg": item.message, "path": list(item.path)}
for item in error.invalid
],
}
elif len(error.invalid) == 1:
item = error.invalid[0]
error_data = {
"kind": "invalid",
"msg": item.message,
"path": list(item.path),
}
return error.message, (error.type or None), error_data
__all__ = ["ProtobufCodec"]
@@ -26,8 +26,10 @@ from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import DATA_INSTANCES
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED
from ._proto import sandbox_v2_pb2 as pb
from .approved_domains import ApprovedDomains
from .channel import Channel
from .messages import make_entity_description
from .protocol import MSG_REGISTER_ENTITY, MSG_STATE_CHANGED, MSG_UNREGISTER_ENTITY
_LOGGER = logging.getLogger(__name__)
@@ -198,11 +200,16 @@ class EntityBridge:
if payload is None:
return
new_hash = _payload_hash(payload)
initial_state = None
initial_attributes = None
if hasattr(new_state, "state"):
payload["initial_state"] = new_state.state
payload["initial_attributes"] = dict(new_state.attributes)
initial_state = new_state.state
initial_attributes = dict(new_state.attributes)
try:
await self._channel.call(MSG_REGISTER_ENTITY, payload)
await self._channel.call(
MSG_REGISTER_ENTITY,
_to_entity_description(payload, initial_state, initial_attributes),
)
except Exception:
_LOGGER.exception("EntityBridge: register failed for %s", entity_id)
return
@@ -228,12 +235,17 @@ class EntityBridge:
new_hash = _payload_hash(payload)
if self._last_hash.get(entity_id) == new_hash:
return
initial_state = None
initial_attributes = None
state = self.hass.states.get(entity_id)
if state is not None:
payload["initial_state"] = state.state
payload["initial_attributes"] = dict(state.attributes)
initial_state = state.state
initial_attributes = dict(state.attributes)
try:
await self._channel.call(MSG_REGISTER_ENTITY, payload)
await self._channel.call(
MSG_REGISTER_ENTITY,
_to_entity_description(payload, initial_state, initial_attributes),
)
except Exception:
_LOGGER.exception("EntityBridge: resend failed for %s", entity_id)
return
@@ -242,15 +254,17 @@ class EntityBridge:
async def _push_state(self, entity_id: str, new_state: Any) -> None:
if self._channel is None:
return
payload = {
"sandbox_entity_id": entity_id,
"new_state": {
"state": new_state.state,
"attributes": dict(new_state.attributes),
},
}
msg = pb.StateChanged(sandbox_entity_id=entity_id)
if new_state.state is not None:
msg.state = new_state.state
msg.attributes.update(dict(new_state.attributes))
# Forward only the context id — never parent_id / user_id. Main
# resolves it to a Context attributed to the sandbox system user.
context = getattr(new_state, "context", None)
if context is not None and context.id:
msg.context_id = context.id
try:
await self._channel.push(MSG_STATE_CHANGED, payload)
await self._channel.push(MSG_STATE_CHANGED, msg)
except Exception:
_LOGGER.exception("EntityBridge: state push failed for %s", entity_id)
@@ -259,12 +273,34 @@ class EntityBridge:
return
try:
await self._channel.call(
MSG_UNREGISTER_ENTITY, {"sandbox_entity_id": entity_id}
MSG_UNREGISTER_ENTITY, pb.UnregisterEntity(sandbox_entity_id=entity_id)
)
except Exception:
_LOGGER.exception(
"EntityBridge: unregister failed for %s", entity_id
)
_LOGGER.exception("EntityBridge: unregister failed for %s", entity_id)
def _to_entity_description(
payload: dict[str, Any],
initial_state: str | None,
initial_attributes: dict[str, Any] | None,
) -> pb.EntityDescription:
"""Build the typed ``EntityDescription`` message from a describe dict."""
return make_entity_description(
entry_id=payload["entry_id"],
domain=payload["domain"],
sandbox_entity_id=payload["sandbox_entity_id"],
unique_id=payload.get("unique_id"),
name=payload.get("name"),
icon=payload.get("icon"),
has_entity_name=bool(payload.get("has_entity_name", False)),
entity_category=payload.get("entity_category"),
device_class=payload.get("device_class"),
supported_features=int(payload.get("supported_features") or 0),
capabilities=payload.get("capabilities"),
initial_state=initial_state,
initial_attributes=initial_attributes,
device_info=payload.get("device_info"),
)
def _payload_hash(payload: dict[str, Any]) -> str:
@@ -8,16 +8,16 @@ and reports back. Main holds the canonical entry; the sandbox copy is
ephemeral state used by the integration's lifecycle hooks.
"""
from collections.abc import Mapping
import logging
from types import MappingProxyType
from typing import Any
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.core import HomeAssistant
from ._proto import sandbox_v2_pb2 as pb
from .approved_domains import ApprovedDomains
from .channel import Channel
from .messages import dict_to_struct, struct_to_dict
from .protocol import MSG_CALL_SERVICE, MSG_ENTRY_SETUP, MSG_ENTRY_UNLOAD
_LOGGER = logging.getLogger(__name__)
@@ -43,16 +43,16 @@ class EntryRunner:
channel.register(MSG_ENTRY_UNLOAD, self._handle_entry_unload)
channel.register(MSG_CALL_SERVICE, self._handle_call_service)
async def _handle_entry_setup(self, payload: Mapping[str, Any]) -> dict[str, Any]:
async def _handle_entry_setup(self, msg: pb.EntrySetup) -> pb.EntrySetupResult:
"""Build a :class:`ConfigEntry`, register it, and call async_setup."""
try:
entry = _entry_from_payload(payload)
entry = _entry_from_proto(msg)
except (KeyError, TypeError) as err:
return {"ok": False, "reason": f"bad payload: {err}"}
return pb.EntrySetupResult(ok=False, reason=f"bad payload: {err}")
config_entries = self.hass.config_entries
if config_entries.async_get_entry(entry.entry_id) is not None:
return {"ok": False, "reason": "entry already loaded"}
return pb.EntrySetupResult(ok=False, reason="entry already loaded")
# ConfigEntries doesn't expose a "add without persist" hook; the
# sandbox's instance has no Store backing, so we drop the entry
@@ -65,38 +65,35 @@ class EntryRunner:
_LOGGER.exception(
"sandbox entry_setup raised for %s (%s)", entry.title, entry.domain
)
return {"ok": False, "reason": str(err) or err.__class__.__name__}
return pb.EntrySetupResult(
ok=False, reason=str(err) or err.__class__.__name__
)
if not ok:
return {
"ok": False,
"reason": entry.reason or f"async_setup returned {ok!r}",
}
return pb.EntrySetupResult(
ok=False, reason=entry.reason or f"async_setup returned {ok!r}"
)
self.approved.add(entry.domain)
return {"ok": True}
return pb.EntrySetupResult(ok=True)
async def _handle_entry_unload(
self, payload: Mapping[str, Any]
) -> dict[str, Any]:
async def _handle_entry_unload(self, msg: pb.EntryUnload) -> pb.EntryUnloadResult:
"""Unload an entry by id and drop it from the sandbox's store."""
entry_id = payload["entry_id"]
entry_id = msg.entry_id
config_entries = self.hass.config_entries
entry = config_entries.async_get_entry(entry_id)
if entry is None:
return {"ok": True}
return pb.EntryUnloadResult(ok=True)
try:
unloaded = await config_entries.async_unload(entry_id)
except Exception as err:
except Exception:
_LOGGER.exception("sandbox entry_unload raised for %s", entry_id)
return {"ok": False, "reason": str(err) or err.__class__.__name__}
return pb.EntryUnloadResult(ok=False)
config_entries._entries.pop(entry_id, None) # noqa: SLF001
# Drop one approval refcount; another loaded entry of the same
# domain keeps it approved.
self.approved.remove(entry.domain)
return {"ok": bool(unloaded)}
return pb.EntryUnloadResult(ok=bool(unloaded))
async def _handle_call_service(
self, payload: Mapping[str, Any]
) -> Any:
async def _handle_call_service(self, msg: pb.CallService) -> pb.CallServiceResult:
"""Dispatch a main→sandbox service call through HA's normal path.
Service-handler errors propagate as raised exceptions so the
@@ -104,47 +101,46 @@ class EntryRunner:
``Invalid``). Main maps those back to ``TypeError`` /
``HomeAssistantError`` in :mod:`bridge`'s exception translator.
"""
domain = payload["domain"]
service = payload["service"]
target = payload.get("target") or {}
service_data = dict(payload.get("service_data") or {})
return_response = bool(payload.get("return_response", False))
if return_response:
target = struct_to_dict(msg.target)
service_data = struct_to_dict(msg.service_data)
if msg.return_response:
result = await self.hass.services.async_call(
domain,
service,
msg.domain,
msg.service,
service_data,
blocking=True,
target=target,
return_response=True,
)
return {"response": result}
response = pb.CallServiceResult()
response.response.data.CopyFrom(dict_to_struct(result or {}))
return response
await self.hass.services.async_call(
domain,
service,
msg.domain,
msg.service,
service_data,
blocking=True,
target=target,
)
return None
return pb.CallServiceResult()
def _entry_from_payload(payload: Mapping[str, Any]) -> ConfigEntry:
"""Rebuild a :class:`ConfigEntry` from the wire payload.
def _entry_from_proto(msg: pb.EntrySetup) -> ConfigEntry:
"""Rebuild a :class:`ConfigEntry` from the typed ``EntrySetup`` message.
Only fields the integration's setup hooks need are surfaced — the
sandbox does not persist entries or track update listeners.
"""
return ConfigEntry(
version=int(payload["version"]),
minor_version=int(payload.get("minor_version", 1)),
domain=payload["domain"],
title=payload.get("title", ""),
data=MappingProxyType(dict(payload.get("data") or {})),
options=MappingProxyType(dict(payload.get("options") or {})),
source=payload.get("source", "user"),
unique_id=payload.get("unique_id"),
entry_id=payload["entry_id"],
version=msg.version,
minor_version=msg.minor_version,
domain=msg.domain,
title=msg.title,
data=MappingProxyType(struct_to_dict(msg.data)),
options=MappingProxyType(struct_to_dict(msg.options)),
source=msg.source,
unique_id=msg.unique_id if msg.HasField("unique_id") else None,
entry_id=msg.entry_id,
discovery_keys=MappingProxyType({}),
subentries_data=None,
state=ConfigEntryState.NOT_LOADED,
@@ -40,6 +40,7 @@ from homeassistant.const import (
)
from homeassistant.core import Event, HomeAssistant, callback
from ._proto import sandbox_v2_pb2 as pb
from .approved_domains import ApprovedDomains
from .channel import Channel
from .protocol import MSG_FIRE_EVENT
@@ -101,25 +102,22 @@ class EventMirror:
return
if not self.approved.approves_event(event_type):
return
payload: dict[str, Any] = {
"event_type": event_type,
"event_data": _to_json_safe(dict(event.data)),
}
if event.context is not None:
payload["context_id"] = event.context.id
msg = pb.FireEvent(event_type=event_type)
msg.event_data.update(_to_json_safe(dict(event.data)))
# Forward only the context id — never parent_id / user_id.
if event.context is not None and event.context.id:
msg.context_id = event.context.id
asyncio.create_task( # noqa: RUF006
self._push(payload),
self._push(msg),
name=f"sandbox_v2:fire_event:{event_type}",
)
async def _push(self, payload: dict[str, Any]) -> None:
async def _push(self, msg: pb.FireEvent) -> None:
assert self._channel is not None
try:
await self._channel.push(MSG_FIRE_EVENT, payload)
await self._channel.push(MSG_FIRE_EVENT, msg)
except Exception:
_LOGGER.exception(
"EventMirror: forward failed for %s", payload["event_type"]
)
_LOGGER.exception("EventMirror: forward failed for %s", msg.event_type)
def _to_json_safe(value: Any) -> Any:
@@ -31,37 +31,35 @@ from homeassistant.config_entries import (
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType, UnknownFlow
from ._proto import sandbox_v2_pb2 as pb
from .channel import Channel
from .messages import struct_to_dict
from .schema_bridge import serialize_schema
_LOGGER = logging.getLogger(__name__)
# Fields we copy verbatim from the integration's FlowResult onto the wire.
# Anything not listed here is either skipped (``progress_task``,
# ``data_schema``) or has bespoke handling below.
_SAFE_RESULT_FIELDS = (
"type",
# Scalar optional-string fields copied verbatim from the integration's
# FlowResult onto the proto. Dynamic dicts (data / options / errors /
# description_placeholders / context) and data_schema get bespoke handling in
# ``_marshal_result``. Result types beyond FORM / CREATE_ENTRY / ABORT carry no
# extra fields (e.g. menu_options) — the main-side proxy only supports those
# three and aborts noisily on anything else.
_SCALAR_STRING_FIELDS = (
"flow_id",
"handler",
"step_id",
"errors",
"description_placeholders",
"description",
"last_step",
"preview",
"reason",
"title",
"description",
)
# Dynamic dict fields → Struct fields of the same name on the proto.
_STRUCT_FIELDS = (
"data",
"options",
"subentries",
"version",
"minor_version",
"menu_options",
"url",
"progress_action",
"translation_domain",
"context",
"errors",
"description_placeholders",
)
@@ -121,36 +119,35 @@ class FlowRunner:
flow_manager.async_abort(progress["flow_id"])
await self.hass.async_block_till_done()
async def _handle_flow_init(self, payload: Mapping[str, Any]) -> dict[str, Any]:
handler = payload["handler"]
context = dict(payload.get("context") or {})
data = payload.get("data")
async def _handle_flow_init(self, msg: pb.FlowInit) -> pb.FlowResult:
context = struct_to_dict(msg.context)
data = struct_to_dict(msg.data) if msg.HasField("data") else None
result = await self.hass.config_entries.flow.async_init(
handler, context=context, data=data
msg.handler, context=context, data=data
)
return _marshal_result(result, self.hass.config_entries.flow)
async def _handle_flow_step(self, payload: Mapping[str, Any]) -> dict[str, Any]:
flow_id = payload["flow_id"]
user_input = payload.get("user_input")
async def _handle_flow_step(self, msg: pb.FlowStep) -> pb.FlowResult:
user_input = (
struct_to_dict(msg.user_input) if msg.HasField("user_input") else None
)
result = await self.hass.config_entries.flow.async_configure(
flow_id, user_input
msg.flow_id, user_input
)
return _marshal_result(result, self.hass.config_entries.flow)
async def _handle_flow_abort(self, payload: Mapping[str, Any]) -> dict[str, Any]:
flow_id = payload["flow_id"]
async def _handle_flow_abort(self, msg: pb.FlowAbort) -> pb.FlowAbortResult:
with contextlib.suppress(UnknownFlow):
# Idempotent — main may have already given up on the flow.
self.hass.config_entries.flow.async_abort(flow_id)
return {}
self.hass.config_entries.flow.async_abort(msg.flow_id)
return pb.FlowAbortResult()
def _marshal_result(
result: Mapping[str, Any],
flow_manager: ConfigEntriesFlowManager | None = None,
) -> dict[str, Any]:
"""Strip a FlowResult down to JSON-serialisable fields.
) -> pb.FlowResult:
"""Marshal a FlowResult into the typed ``FlowResult`` message.
``data_schema`` is rendered via :func:`serialize_schema` (Phase 14)
the wire payload carries the same list-of-fields shape
@@ -159,16 +156,32 @@ def _marshal_result(
carries ``unique_id`` once the integration calls
:meth:`ConfigFlow.async_set_unique_id`) is pulled out of the live
flow when the result type doesn't already include it.
Only FORM / CREATE_ENTRY / ABORT fields are carried the main-side proxy
supports only those three and aborts noisily on anything else, so
``menu_options`` / ``subentries`` / ``url`` / are intentionally dropped.
"""
out: dict[str, Any] = {}
for key in _SAFE_RESULT_FIELDS:
if key not in result:
continue
out[key] = _to_json_safe(result[key])
if "data_schema" in result and result["data_schema"] is not None:
out = pb.FlowResult(type=_flow_type_value(result["type"]))
for key in _SCALAR_STRING_FIELDS:
value = result.get(key)
if value is not None:
setattr(out, key, str(value))
if result.get("version") is not None:
out.version = int(result["version"])
if result.get("minor_version") is not None:
out.minor_version = int(result["minor_version"])
if result.get("last_step") is not None:
out.last_step = bool(result["last_step"])
if result.get("preview") is not None:
out.preview = str(result["preview"])
for key in _STRUCT_FIELDS:
value = result.get(key)
if isinstance(value, Mapping):
getattr(out, key).update(_to_json_safe(dict(value)))
if result.get("data_schema") is not None:
serialized = serialize_schema(result["data_schema"])
if serialized is not None:
out["data_schema"] = serialized
out.data_schema.extend(serialized)
else:
# voluptuous_serialize couldn't render it; flag the gap so the
# proxy still surfaces a (schema-less) form rather than abort.
@@ -179,12 +192,15 @@ def _marshal_result(
" schema-less form",
result["data_schema"],
)
out["_has_data_schema"] = True
# FORM / SHOW_PROGRESS / EXTERNAL_STEP results don't include the
# flow's context (only CREATE_ENTRY does). Look it up so the proxy
# can mirror ``unique_id`` into its own ``self.context`` and let
# main's duplicate detection fire.
if "context" not in out and flow_manager is not None:
out.has_data_schema = True
context_value = result.get("context")
if isinstance(context_value, Mapping):
out.context.update(_to_json_safe(dict(context_value)))
elif flow_manager is not None:
# FORM / SHOW_PROGRESS / EXTERNAL_STEP results don't include the
# flow's context (only CREATE_ENTRY does). Look it up so the proxy
# can mirror ``unique_id`` into its own ``self.context`` and let
# main's duplicate detection fire.
flow_id = result.get("flow_id")
if isinstance(flow_id, str):
try:
@@ -194,10 +210,17 @@ def _marshal_result(
if partial is not None:
ctx = partial.get("context")
if isinstance(ctx, Mapping):
out["context"] = _to_json_safe(ctx)
out.context.update(_to_json_safe(dict(ctx)))
return out
def _flow_type_value(value: Any) -> str:
"""Return the string value of a FlowResult ``type`` (enum or string)."""
if isinstance(value, FlowResultType):
return value.value
return str(value)
def _to_json_safe(value: Any) -> Any:
"""Recursively coerce a value into JSON-safe primitives."""
if isinstance(value, Mapping):
@@ -0,0 +1,222 @@
"""Typed protobuf message registry + dynamic-field helpers.
This module is the codec's view of the wire: the ``type → (request_cls,
result_cls)`` registry plus the small Struct/ListValue helpers that carry the
genuinely dynamic payloads (service_data, target, state attributes,
capabilities, the wrapped Store envelope, flow ``data``/``errors``/``context``)
and the serialized voluptuous schema.
Mirrored verbatim across the no-cross-import boundary, exactly like
:mod:`channel` / :mod:`protocol`: the same file lives at
``hass_client.messages``. The relative ``._proto`` import resolves to each
side's own checked-in gencode, so the two copies are byte-identical.
Numbers note: ``google.protobuf.Struct`` stores every number as a double, so
an ``int`` that crosses inside a dynamic field comes back as a ``float``
(``255`` ``255.0``). Python's ``==`` treats the two as equal, so dict
comparisons still hold; only an ``isinstance(x, int)`` check would notice.
Everything with integer semantics that matters (``version``, ``minor_version``,
``supported_features``) is an explicit ``int32`` field, not a Struct value.
"""
from typing import Any
from google.protobuf.message import Message
# pylint: disable-next=no-name-in-module
from google.protobuf.struct_pb2 import ListValue, Struct, Value
from ._proto import sandbox_v2_pb2 as pb
# Wire type → (request message class, result message class). The result class
# is ``None`` for one-way pushes (ready / state_changed / fire_event). The
# codec resolves these from ``frame.type`` on both encode and decode.
REGISTRY: dict[str, tuple[type[Message], type[Message] | None]] = {
# handshake (push)
"sandbox_v2/ready": (pb.Ready, None),
# main → sandbox
"sandbox_v2/entry_setup": (pb.EntrySetup, pb.EntrySetupResult),
"sandbox_v2/entry_unload": (pb.EntryUnload, pb.EntryUnloadResult),
"sandbox_v2/call_service": (pb.CallService, pb.CallServiceResult),
"sandbox_v2/shutdown": (pb.Shutdown, pb.ShutdownResult),
"sandbox_v2/ping": (pb.Ping, pb.PingResult),
"sandbox_v2/flow_init": (pb.FlowInit, pb.FlowResult),
"sandbox_v2/flow_step": (pb.FlowStep, pb.FlowResult),
"sandbox_v2/flow_abort": (pb.FlowAbort, pb.FlowAbortResult),
# sandbox → main
"sandbox_v2/register_entity": (pb.EntityDescription, pb.RegisterEntityResult),
"sandbox_v2/unregister_entity": (pb.UnregisterEntity, pb.UnregisterEntityResult),
"sandbox_v2/state_changed": (pb.StateChanged, None),
"sandbox_v2/register_service": (pb.RegisterService, pb.RegisterServiceResult),
"sandbox_v2/unregister_service": (
pb.UnregisterService,
pb.UnregisterServiceResult,
),
"sandbox_v2/fire_event": (pb.FireEvent, None),
"sandbox_v2/store_load": (pb.StoreLoad, pb.StoreLoadResult),
"sandbox_v2/store_save": (pb.StoreSave, pb.StoreSaveResult),
"sandbox_v2/store_remove": (pb.StoreRemove, pb.StoreRemoveResult),
}
# --- Struct / ListValue helpers -------------------------------------------
def _value_to_py(value: Value) -> Any:
"""Convert one ``google.protobuf.Value`` into a plain Python value."""
kind = value.WhichOneof("kind")
if kind == "null_value" or kind is None:
return None
if kind == "number_value":
return value.number_value
if kind == "string_value":
return value.string_value
if kind == "bool_value":
return value.bool_value
if kind == "struct_value":
return struct_to_dict(value.struct_value)
return [_value_to_py(item) for item in value.list_value.values]
def struct_to_dict(struct: Struct) -> dict[str, Any]:
"""Convert a ``Struct`` into a plain ``dict`` (empty Struct → ``{}``)."""
return {key: _value_to_py(val) for key, val in struct.fields.items()}
def dict_to_struct(data: dict[str, Any] | None) -> Struct:
"""Convert a ``dict`` (or ``None``) into a ``Struct``."""
struct = Struct()
if data:
struct.update(data)
return struct
def listvalue_to_list(list_value: ListValue) -> list[Any]:
"""Convert a ``ListValue`` into a plain ``list``."""
return [_value_to_py(item) for item in list_value.values]
def list_to_listvalue(items: list[Any] | None) -> ListValue:
"""Convert a ``list`` (or ``None``) into a ``ListValue``."""
list_value = ListValue()
if items:
list_value.extend(items)
return list_value
# --- DeviceInfo bridging --------------------------------------------------
# Scalar string fields of the DeviceInfo proto, copied through verbatim when
# present in the JSON-flattened device_info dict.
_DEVICE_INFO_SCALARS = (
"entry_type",
"name",
"manufacturer",
"model",
"model_id",
"sw_version",
"hw_version",
"serial_number",
"suggested_area",
"configuration_url",
"default_name",
"default_manufacturer",
"default_model",
"translation_key",
)
def device_info_to_proto(flat: dict[str, Any] | None) -> pb.DeviceInfo | None:
"""Build a ``DeviceInfo`` proto from the JSON-flattened device_info dict.
The sandbox-side serializer (``entity_bridge._serialise_device_info``)
already flattens sets/tuples/enums: ``identifiers`` / ``connections`` are
lists of two-element lists, ``via_device`` is a two-element list, and
``entry_type`` is the enum's string value. This maps that shape onto the
explicit proto fields.
"""
if not flat:
return None
info = pb.DeviceInfo()
for key, raw in flat.items():
if raw is None:
continue
if key in ("identifiers", "connections"):
for pair in raw:
if len(pair) == 2:
getattr(info, key).add(key=str(pair[0]), value=str(pair[1]))
elif key == "via_device":
if len(raw) == 2:
info.via_device.key = str(raw[0])
info.via_device.value = str(raw[1])
elif key in _DEVICE_INFO_SCALARS:
setattr(info, key, str(raw))
return info
def make_entity_description(
*,
entry_id: str,
domain: str,
sandbox_entity_id: str,
unique_id: str | None = None,
name: str | None = None,
icon: str | None = None,
has_entity_name: bool = False,
entity_category: str | None = None,
device_class: str | None = None,
supported_features: int = 0,
translation_key: str | None = None,
capabilities: dict[str, Any] | None = None,
initial_state: str | None = None,
initial_attributes: dict[str, Any] | None = None,
device_info: dict[str, Any] | None = None,
) -> pb.EntityDescription:
"""Build a nested ``EntityDescription`` proto from flat fields.
Used by the sandbox entity bridge and by tests so neither has to hand-nest
the ``EntityInfo`` / ``InitialState`` sub-messages. ``device_info`` is the
JSON-flattened dict the entity bridge produces (see
:func:`device_info_to_proto`).
"""
msg = pb.EntityDescription(
entry_id=entry_id,
domain=domain,
sandbox_entity_id=sandbox_entity_id,
has_entity_name=has_entity_name,
)
if unique_id is not None:
msg.unique_id = unique_id
description = msg.info.description
if name is not None:
description.name = name
if icon is not None:
description.icon = icon
if entity_category is not None:
description.entity_category = entity_category
if device_class is not None:
description.device_class = device_class
description.supported_features = int(supported_features or 0)
if translation_key is not None:
description.translation_key = translation_key
device = device_info_to_proto(device_info)
if device is not None:
msg.info.device_info.CopyFrom(device)
if initial_state is not None:
msg.initial.state = initial_state
if capabilities:
msg.initial.capabilities.update(capabilities)
if initial_attributes:
msg.initial.attributes.update(initial_attributes)
return msg
__all__ = [
"REGISTRY",
"device_info_to_proto",
"dict_to_struct",
"list_to_listvalue",
"listvalue_to_list",
"make_entity_description",
"struct_to_dict",
]
+16 -30
View File
@@ -19,7 +19,7 @@ channel — see Phase 9's :meth:`SandboxRuntime._handle_shutdown`).
"""
import asyncio
from collections.abc import Awaitable, Callable, Mapping
from collections.abc import Awaitable, Callable
import contextlib
import json
import logging
@@ -34,8 +34,10 @@ from homeassistant.core import CoreState
from homeassistant.helpers import json as json_helper, restore_state
from homeassistant.helpers.sandbox_context import current_sandbox
from ._proto import sandbox_v2_pb2 as pb
from .approved_domains import ApprovedDomains
from .channel import Channel
from .codec_protobuf import ProtobufCodec
from .entity_bridge import EntityBridge
from .entry_runner import EntryRunner
from .event_mirror import EventMirror
@@ -128,9 +130,7 @@ class SandboxRuntime:
with contextlib.suppress(NotImplementedError):
loop.add_signal_handler(sig, self._shutdown.set)
_LOGGER.info(
"sandbox_v2 runtime ready (group=%s url=%s)", self.group, self.url
)
_LOGGER.info("sandbox_v2 runtime ready (group=%s url=%s)", self.group, self.url)
# Set up the HA instance + flow runner before the marker so the
# first manager call after the handshake cannot race.
@@ -224,7 +224,7 @@ class SandboxRuntime:
"""Open a :class:`Channel` over stdin/stdout for the manager."""
return await _open_stdio_channel(name=self.group)
async def _handle_shutdown(self, _payload: Mapping[str, Any] | None) -> dict[str, Any]:
async def _handle_shutdown(self, _payload: object) -> pb.ShutdownResult:
"""Phase 9: unload entries, flush restore state, then exit cleanly.
Runs inside the channel dispatcher so the reply is written before
@@ -238,7 +238,7 @@ class SandboxRuntime:
asyncio.get_running_loop().call_soon(self._shutdown.set)
return summary
async def _run_graceful_shutdown(self) -> dict[str, Any]:
async def _run_graceful_shutdown(self) -> pb.ShutdownResult:
"""Unload every loaded entry and snapshot RestoreEntity state.
Phase 12 fires ``EVENT_HOMEASSISTANT_FINAL_WRITE`` and waits for
@@ -258,11 +258,7 @@ class SandboxRuntime:
"""
flow_runner = self._flow_runner
if flow_runner is None:
return {
"ok": True,
"unloaded": 0,
"restore_state": None,
}
return pb.ShutdownResult(ok=True, unloaded=0)
hass = flow_runner.hass
unloaded = 0
@@ -288,11 +284,9 @@ class SandboxRuntime:
hass.bus.async_fire_internal(EVENT_HOMEASSISTANT_FINAL_WRITE)
await hass.async_block_till_done()
except Exception:
_LOGGER.exception(
"sandbox %s: FINAL_WRITE flush failed", self.group
)
_LOGGER.exception("sandbox %s: FINAL_WRITE flush failed", self.group)
restore_payload: dict[str, Any] | None = None
result = pb.ShutdownResult(ok=True, unloaded=unloaded)
try:
restore_data = restore_state.async_get(hass)
stored = restore_data.async_get_stored_states()
@@ -307,20 +301,12 @@ class SandboxRuntime:
"key": restore_state.STORAGE_KEY,
"data": [item.as_dict() for item in stored],
}
_mode, json_bytes = json_helper.prepare_save_json(
wrapped, encoder=None
)
restore_payload = json.loads(json_bytes)
_mode, json_bytes = json_helper.prepare_save_json(wrapped, encoder=None)
result.restore_state.update(json.loads(json_bytes))
except Exception:
_LOGGER.exception(
"sandbox %s: restore-state collect failed", self.group
)
_LOGGER.exception("sandbox %s: restore-state collect failed", self.group)
return {
"ok": True,
"unloaded": unloaded,
"restore_state": restore_payload,
}
return result
async def _load_restore_state(hass: Any) -> None:
@@ -360,12 +346,12 @@ async def _open_stdio_channel(*, name: str) -> Channel:
os.fdopen(sys.stdout.fileno(), "wb"),
)
writer = asyncio.StreamWriter(transport, protocol, reader=None, loop=loop)
return Channel(reader, writer, name=name)
return Channel(reader, writer, name=name, codec=ProtobufCodec())
async def _handle_ping(_payload: object) -> dict[str, str]:
async def _handle_ping(_payload: object) -> pb.PingResult:
"""Health-check handler — manager-side polling uses this round-trip."""
return {"pong": "sandbox_v2"}
return pb.PingResult(pong="sandbox_v2")
__all__ = ["SandboxRuntime"]
@@ -20,7 +20,9 @@ from typing import Any
from homeassistant.helpers import json as json_helper
from homeassistant.util.json import SerializationError
from ._proto import sandbox_v2_pb2 as pb
from .channel import Channel, ChannelClosedError, ChannelRemoteError
from .messages import dict_to_struct, struct_to_dict
from .protocol import MSG_STORE_LOAD, MSG_STORE_REMOVE, MSG_STORE_SAVE
_LOGGER = logging.getLogger(__name__)
@@ -46,23 +48,16 @@ class ChannelSandboxBridge:
or ``None`` when main has no data / the channel is unavailable.
"""
try:
wrapped = await self._channel.call(MSG_STORE_LOAD, {"key": key})
result = await self._channel.call(MSG_STORE_LOAD, pb.StoreLoad(key=key))
except ChannelClosedError:
_LOGGER.warning("sandbox store[%s]: channel closed mid-load", key)
return None
except ChannelRemoteError as err:
_LOGGER.warning("sandbox store[%s] load failed: %s", key, err)
return None
if wrapped is None:
if not result.HasField("data"):
return None
if not isinstance(wrapped, dict):
_LOGGER.error(
"sandbox store[%s]: main returned non-dict (%s)",
key,
type(wrapped).__name__,
)
return None
return wrapped
return struct_to_dict(result.data)
async def async_store_save(self, key: str, data: Any) -> None:
"""Push the wrapped payload to main instead of writing to disk.
@@ -84,7 +79,9 @@ class ChannelSandboxBridge:
_LOGGER.exception("sandbox store[%s]: payload not serialisable", key)
return
try:
await self._channel.call(MSG_STORE_SAVE, {"key": key, "data": payload})
await self._channel.call(
MSG_STORE_SAVE, pb.StoreSave(key=key, data=dict_to_struct(payload))
)
except ChannelClosedError:
_LOGGER.warning("sandbox store[%s]: channel closed mid-save", key)
except ChannelRemoteError as err:
@@ -93,7 +90,7 @@ class ChannelSandboxBridge:
async def async_store_remove(self, key: str) -> None:
"""Unlink ``key`` on main, not on local disk."""
try:
await self._channel.call(MSG_STORE_REMOVE, {"key": key})
await self._channel.call(MSG_STORE_REMOVE, pb.StoreRemove(key=key))
except ChannelClosedError:
_LOGGER.warning("sandbox store[%s]: channel closed mid-remove", key)
except ChannelRemoteError as err:
@@ -35,7 +35,7 @@ def serialize_schema(schema: Any) -> list[dict[str, Any]] | None:
rendered = voluptuous_serialize.convert(
schema, custom_serializer=cv.custom_serializer
)
except (ValueError, TypeError):
except ValueError, TypeError:
return None
if not isinstance(rendered, list):
return None
@@ -26,6 +26,7 @@ from homeassistant.const import (
)
from homeassistant.core import Event, HomeAssistant, callback
from ._proto import sandbox_v2_pb2 as pb
from .approved_domains import ApprovedDomains
from .channel import Channel
from .protocol import MSG_REGISTER_SERVICE, MSG_UNREGISTER_SERVICE
@@ -92,17 +93,17 @@ class ServiceMirror:
if key in self._mirrored:
return
supports_response = _supports_response(self.hass, domain, service)
payload: dict[str, Any] = {
"domain": domain,
"service": service,
"supports_response": supports_response,
}
msg = pb.RegisterService(
domain=domain,
service=service,
supports_response=supports_response,
)
schema = _service_schema(self.hass, domain, service)
if schema is not None:
payload["schema"] = schema
if schema:
msg.schema.extend(schema)
self._mirrored.add(key)
asyncio.create_task( # noqa: RUF006
self._push_register(payload, key),
self._push_register(msg, key),
name=f"sandbox_v2:register_service:{domain}.{service}",
)
@@ -116,36 +117,36 @@ class ServiceMirror:
if key not in self._mirrored:
return
self._mirrored.discard(key)
payload = {"domain": domain, "service": service}
msg = pb.UnregisterService(domain=domain, service=service)
asyncio.create_task( # noqa: RUF006
self._push_unregister(payload),
self._push_unregister(msg),
name=f"sandbox_v2:unregister_service:{domain}.{service}",
)
async def _push_register(
self, payload: dict[str, Any], key: tuple[str, str]
self, msg: pb.RegisterService, key: tuple[str, str]
) -> None:
assert self._channel is not None
try:
await self._channel.call(MSG_REGISTER_SERVICE, payload)
await self._channel.call(MSG_REGISTER_SERVICE, msg)
except Exception:
_LOGGER.exception(
"ServiceMirror: register failed for %s.%s",
payload["domain"],
payload["service"],
msg.domain,
msg.service,
)
# Roll back the mirrored bookkeeping so a retry can succeed.
self._mirrored.discard(key)
async def _push_unregister(self, payload: dict[str, Any]) -> None:
async def _push_unregister(self, msg: pb.UnregisterService) -> None:
assert self._channel is not None
try:
await self._channel.call(MSG_UNREGISTER_SERVICE, payload)
await self._channel.call(MSG_UNREGISTER_SERVICE, msg)
except Exception:
_LOGGER.exception(
"ServiceMirror: unregister failed for %s.%s",
payload["domain"],
payload["service"],
msg.domain,
msg.service,
)
@@ -56,7 +56,7 @@ def classify_domain_sync(domain: str) -> str | None:
return GROUP_CUSTOM
try:
manifest = json.loads(manifest_path.read_text())
except (OSError, json.JSONDecodeError):
except OSError, json.JSONDecodeError:
return None
if manifest.get("integration_type") == "system":
return None
@@ -49,9 +49,7 @@ class _LoopbackWriter:
"""No-op — :meth:`close` is synchronous for the loopback."""
def make_inproc_channel_pair(
*, group: str
) -> tuple[Any, ClientChannel]:
def make_inproc_channel_pair(*, group: str) -> tuple[Any, ClientChannel]:
"""Return ``(manager_channel, runtime_channel)`` joined in-memory.
The manager-side channel is the one the HA integration's
@@ -67,9 +65,13 @@ def make_inproc_channel_pair(
# Lazy import: the manager-side Channel lives in the HA integration
# tree. Importing it eagerly would couple the testing helper to a
# component that may not be loaded.
from hass_client.codec_protobuf import ProtobufCodec as ClientCodec # noqa: PLC0415
from homeassistant.components.sandbox_v2.channel import ( # noqa: PLC0415
Channel as MgrChannel,
)
from homeassistant.components.sandbox_v2.codec_protobuf import ( # noqa: PLC0415
ProtobufCodec as MgrCodec,
)
reader_a = asyncio.StreamReader()
reader_b = asyncio.StreamReader()
@@ -77,8 +79,12 @@ def make_inproc_channel_pair(
# writer_b writes → reader_a feeds (manager reads what runtime wrote)
writer_a = _LoopbackWriter(reader_b)
writer_b = _LoopbackWriter(reader_a)
mgr_channel = MgrChannel(reader_a, writer_a, name=f"mgr:{group}")
rt_channel = ClientChannel(reader_b, writer_b, name=f"rt:{group}")
# Each side builds its own ProtobufCodec from its own _proto mirror; the
# wire is identical, so the two are fully interoperable.
mgr_channel = MgrChannel(reader_a, writer_a, name=f"mgr:{group}", codec=MgrCodec())
rt_channel = ClientChannel(
reader_b, writer_b, name=f"rt:{group}", codec=ClientCodec()
)
return mgr_channel, rt_channel
@@ -202,14 +202,10 @@ async def async_setup_inprocess_sandbox(
# Mirror what the integration's ``_on_channel_ready`` does when the
# real ``SandboxProcess`` opens its channel — register the bridge.
data.channels[group] = mgr_channel
data.bridges[group] = async_create_bridge(
hass, group=group, channel=mgr_channel
)
data.bridges[group] = async_create_bridge(hass, group=group, channel=mgr_channel)
mgr_channel.start()
return InProcessSandbox(
group=group, runtime=runtime, runtime_task=runtime_task
)
return InProcessSandbox(group=group, runtime=runtime, runtime_task=runtime_task)
def _one_shot_channel_factory(channel: Any):
@@ -219,9 +215,7 @@ def _one_shot_channel_factory(channel: Any):
async def factory() -> Any:
nonlocal used
if used:
raise RuntimeError(
"in-process SandboxRuntime asked for a second channel"
)
raise RuntimeError("in-process SandboxRuntime asked for a second channel")
used = True
return channel
+1
View File
@@ -13,6 +13,7 @@ authors = [{ name = "Paulus Schoutsen" }]
dependencies = [
"aiohttp>=3.11.0",
"homeassistant",
"protobuf==6.32.0",
]
[project.optional-dependencies]
@@ -17,9 +17,12 @@ import tempfile
from typing import Any
from unittest.mock import MagicMock
from hass_client._proto import sandbox_v2_pb2 as pb
from hass_client.channel import Channel
from hass_client.codec_protobuf import ProtobufCodec
from hass_client.entity_bridge import EntityBridge
from hass_client.flow_runner import FlowRunner
from hass_client.messages import struct_to_dict
import pytest
from homeassistant.config_entries import ConfigEntry
@@ -51,8 +54,12 @@ def _make_channel_pair() -> tuple[Channel, Channel]:
reader_a = asyncio.StreamReader()
reader_b = asyncio.StreamReader()
return (
Channel(reader_a, _LoopbackWriter(reader_b), name="main"), # type: ignore[arg-type]
Channel(reader_b, _LoopbackWriter(reader_a), name="sandbox"), # type: ignore[arg-type]
Channel(
reader_a, _LoopbackWriter(reader_b), name="main", codec=ProtobufCodec()
), # type: ignore[arg-type]
Channel(
reader_b, _LoopbackWriter(reader_a), name="sandbox", codec=ProtobufCodec()
), # type: ignore[arg-type]
)
@@ -151,11 +158,11 @@ async def test_bridge_includes_device_info_in_register_payload(
main, sandbox = channels
hass, component = hass_with_demo_component
register_calls: list[dict[str, Any]] = []
register_calls: list[pb.EntityDescription] = []
async def _on_register(payload: dict[str, Any]) -> dict[str, str]:
register_calls.append(payload)
return {"entity_id": "demo.with_device_main"}
async def _on_register(msg: pb.EntityDescription) -> pb.RegisterEntityResult:
register_calls.append(msg)
return pb.RegisterEntityResult(entity_id="demo.with_device_main")
main.register("sandbox_v2/register_entity", _on_register)
main.start()
@@ -195,11 +202,10 @@ async def test_bridge_includes_device_info_in_register_payload(
await asyncio.sleep(0)
assert len(register_calls) == 1
device_info = register_calls[0].get("device_info")
assert device_info is not None
assert device_info["identifiers"] == [["demo", "dev-1"]]
assert device_info["name"] == "Demo Device"
assert device_info["manufacturer"] == "Acme"
device_info = register_calls[0].info.device_info
assert [(p.key, p.value) for p in device_info.identifiers] == [("demo", "dev-1")]
assert device_info.name == "Demo Device"
assert device_info.manufacturer == "Acme"
await bridge.async_stop()
@@ -211,15 +217,15 @@ async def test_bridge_emits_register_and_state_pushes(
main, sandbox = channels
hass, component = hass_with_demo_component
register_calls: list[dict[str, Any]] = []
state_calls: list[dict[str, Any]] = []
register_calls: list[pb.EntityDescription] = []
state_calls: list[pb.StateChanged] = []
async def _on_register(payload: dict[str, Any]) -> dict[str, str]:
register_calls.append(payload)
return {"entity_id": "demo.lamp_main"}
async def _on_register(msg: pb.EntityDescription) -> pb.RegisterEntityResult:
register_calls.append(msg)
return pb.RegisterEntityResult(entity_id="demo.lamp_main")
async def _on_state(payload: dict[str, Any]) -> None:
state_calls.append(payload)
async def _on_state(msg: pb.StateChanged) -> None:
state_calls.append(msg)
main.register("sandbox_v2/register_entity", _on_register)
main.register("sandbox_v2/state_changed", _on_state)
@@ -257,12 +263,12 @@ async def test_bridge_emits_register_and_state_pushes(
await asyncio.sleep(0)
assert len(register_calls) == 1
payload = register_calls[0]
assert payload["unique_id"] == "demo-lamp"
assert payload["domain"] == "demo"
assert payload["sandbox_entity_id"] == "demo.lamp"
assert payload["entry_id"] == "fake-entry-id"
assert payload["initial_state"] == "off"
msg = register_calls[0]
assert msg.unique_id == "demo-lamp"
assert msg.domain == "demo"
assert msg.sandbox_entity_id == "demo.lamp"
assert msg.entry_id == "fake-entry-id"
assert msg.initial.state == "off"
# A subsequent state change becomes a state_changed push.
new_state = State(entity.entity_id, "on", {"brightness": 200})
@@ -281,16 +287,14 @@ async def test_bridge_emits_register_and_state_pushes(
await asyncio.sleep(0)
assert len(state_calls) == 1
assert state_calls[0]["sandbox_entity_id"] == "demo.lamp"
assert state_calls[0]["new_state"]["state"] == "on"
assert state_calls[0]["new_state"]["attributes"]["brightness"] == 200
assert state_calls[0].sandbox_entity_id == "demo.lamp"
assert state_calls[0].state == "on"
assert struct_to_dict(state_calls[0].attributes)["brightness"] == 200
await bridge.async_stop()
async def _register_initial(
bridge: EntityBridge, hass: Any, entity: Entity
) -> None:
async def _register_initial(bridge: EntityBridge, hass: Any, entity: Entity) -> None:
"""Drive the first state-change so ``entity`` is tracked + registered."""
now = datetime.now(tz=datetime.now().astimezone().tzinfo)
hass.bus.async_fire(
@@ -320,11 +324,11 @@ async def test_entity_registry_update_resends_registration(
main, sandbox = channels
hass, component = hass_with_demo_component
register_calls: list[dict[str, Any]] = []
register_calls: list[pb.EntityDescription] = []
async def _on_register(payload: dict[str, Any]) -> dict[str, str]:
register_calls.append(payload)
return {"entity_id": "demo.lamp_main"}
async def _on_register(msg: pb.EntityDescription) -> pb.RegisterEntityResult:
register_calls.append(msg)
return pb.RegisterEntityResult(entity_id="demo.lamp_main")
main.register("sandbox_v2/register_entity", _on_register)
main.start()
@@ -350,8 +354,8 @@ async def test_entity_registry_update_resends_registration(
await asyncio.sleep(0)
assert len(register_calls) == 2
assert register_calls[1]["name"] == "Renamed Lamp"
assert register_calls[1]["sandbox_entity_id"] == "demo.lamp"
assert register_calls[1].info.description.name == "Renamed Lamp"
assert register_calls[1].sandbox_entity_id == "demo.lamp"
# Let the resend coroutine settle past its await so the description
# hash is recorded before the next event fires.
@@ -377,11 +381,11 @@ async def test_device_registry_update_resends_linked_entities(
main, sandbox = channels
hass, component = hass_with_demo_component
register_calls: list[dict[str, Any]] = []
register_calls: list[pb.EntityDescription] = []
async def _on_register(payload: dict[str, Any]) -> dict[str, str]:
register_calls.append(payload)
return {"entity_id": "demo.lamp_main"}
async def _on_register(msg: pb.EntityDescription) -> pb.RegisterEntityResult:
register_calls.append(msg)
return pb.RegisterEntityResult(entity_id="demo.lamp_main")
main.register("sandbox_v2/register_entity", _on_register)
main.start()
@@ -436,7 +440,7 @@ async def test_device_registry_update_resends_linked_entities(
bridge.register(sandbox)
await _register_initial(bridge, hass, entity)
assert len(register_calls) == 1
assert register_calls[0]["device_info"]["sw_version"] == "1.0"
assert register_calls[0].info.device_info.sw_version == "1.0"
# Firmware bump: the entity now reports a new sw_version and the device
# registry fires its updated event.
@@ -451,6 +455,6 @@ async def test_device_registry_update_resends_linked_entities(
await asyncio.sleep(0)
assert len(register_calls) == 2
assert register_calls[1]["device_info"]["sw_version"] == "2.0"
assert register_calls[1].info.device_info.sw_version == "2.0"
await bridge.async_stop()
@@ -9,7 +9,9 @@ import tempfile
from types import ModuleType
from typing import Any
from hass_client._proto import sandbox_v2_pb2 as pb
from hass_client.channel import Channel
from hass_client.codec_protobuf import ProtobufCodec
from hass_client.entry_runner import EntryRunner
from hass_client.flow_runner import FlowRunner
import pytest
@@ -39,8 +41,12 @@ def _make_channel_pair() -> tuple[Channel, Channel]:
reader_a = asyncio.StreamReader()
reader_b = asyncio.StreamReader()
return (
Channel(reader_a, _LoopbackWriter(reader_b), name="main"), # type: ignore[arg-type]
Channel(reader_b, _LoopbackWriter(reader_a), name="sandbox"), # type: ignore[arg-type]
Channel(
reader_a, _LoopbackWriter(reader_b), name="main", codec=ProtobufCodec()
), # type: ignore[arg-type]
Channel(
reader_b, _LoopbackWriter(reader_a), name="sandbox", codec=ProtobufCodec()
), # type: ignore[arg-type]
)
@@ -95,13 +101,11 @@ async def test_entry_setup_calls_integration_setup_entry(
module.DOMAIN = "phase5_demo"
module.async_setup_entry = _async_setup_entry # type: ignore[attr-defined]
module.async_unload_entry = _async_unload_entry # type: ignore[attr-defined]
config_flow_module = ModuleType(
"homeassistant.components.phase5_demo.config_flow"
)
config_flow_module = ModuleType("homeassistant.components.phase5_demo.config_flow")
runner.hass.data[ha_loader.DATA_COMPONENTS]["phase5_demo"] = module
runner.hass.data[ha_loader.DATA_COMPONENTS][
"phase5_demo.config_flow"
] = config_flow_module
runner.hass.data[ha_loader.DATA_COMPONENTS]["phase5_demo.config_flow"] = (
config_flow_module
)
runner.hass.config.components.add("phase5_demo")
integration = ha_loader.Integration(
@@ -125,20 +129,19 @@ async def test_entry_setup_calls_integration_setup_entry(
)
runner.hass.data[ha_loader.DATA_INTEGRATIONS]["phase5_demo"] = integration
payload = {
"entry_id": "test_entry_id_5",
"domain": "phase5_demo",
"title": "Demo",
"data": {"host": "1.2.3.4"},
"options": {},
"source": "user",
"unique_id": None,
"version": 1,
"minor_version": 1,
}
payload = pb.EntrySetup(
entry_id="test_entry_id_5",
domain="phase5_demo",
title="Demo",
source="user",
version=1,
minor_version=1,
)
payload.data.update({"host": "1.2.3.4"})
result = await main.call("sandbox_v2/entry_setup", payload)
assert result == {"ok": True}
assert result.ok
assert not result.HasField("reason")
assert len(setup_calls) == 1
assert setup_calls[0].entry_id == "test_entry_id_5"
assert setup_calls[0].data["host"] == "1.2.3.4"
@@ -153,21 +156,18 @@ async def test_entry_setup_reports_failure_reason(
main.start()
sandbox.start()
payload = {
"entry_id": "missing_entry_id",
"domain": "phase5_missing",
"title": "Missing",
"data": {},
"options": {},
"source": "user",
"unique_id": None,
"version": 1,
"minor_version": 1,
}
payload = pb.EntrySetup(
entry_id="missing_entry_id",
domain="phase5_missing",
title="Missing",
source="user",
version=1,
minor_version=1,
)
result = await main.call("sandbox_v2/entry_setup", payload)
assert result["ok"] is False
assert "reason" in result
assert result.ok is False
assert result.HasField("reason")
async def test_call_service_dispatches_through_services(
@@ -186,15 +186,11 @@ async def test_call_service_dispatches_through_services(
runner.hass.services.async_register("test_call", "do_it", _svc_handler)
result = await main.call(
"sandbox_v2/call_service",
{
"domain": "test_call",
"service": "do_it",
"target": {},
"service_data": {"hello": "world"},
},
)
call_msg = pb.CallService(domain="test_call", service="do_it")
call_msg.service_data.update({"hello": "world"})
result = await main.call("sandbox_v2/call_service", call_msg)
assert result is None
# No return_response: proto result has no `response` field set (was
# `result is None` on the dict wire).
assert not result.HasField("response")
assert seen == [{"hello": "world"}]
@@ -4,10 +4,13 @@ import asyncio
import tempfile
from typing import Any
from hass_client._proto import sandbox_v2_pb2 as pb
from hass_client.approved_domains import ApprovedDomains
from hass_client.channel import Channel
from hass_client.codec_protobuf import ProtobufCodec
from hass_client.event_mirror import EventMirror
from hass_client.flow_runner import FlowRunner
from hass_client.messages import struct_to_dict
import pytest
@@ -32,8 +35,12 @@ def _make_channel_pair() -> tuple[Channel, Channel]:
reader_a = asyncio.StreamReader()
reader_b = asyncio.StreamReader()
return (
Channel(reader_a, _LoopbackWriter(reader_b), name="main"), # type: ignore[arg-type]
Channel(reader_b, _LoopbackWriter(reader_a), name="sandbox"), # type: ignore[arg-type]
Channel(
reader_a, _LoopbackWriter(reader_b), name="main", codec=ProtobufCodec()
), # type: ignore[arg-type]
Channel(
reader_b, _LoopbackWriter(reader_a), name="sandbox", codec=ProtobufCodec()
), # type: ignore[arg-type]
)
@@ -67,10 +74,10 @@ async def test_owned_domain_event_is_forwarded(
) -> None:
"""``zha_event`` reaches main when ``zha`` is approved."""
main, sandbox = channels
forwarded: list[dict[str, Any]] = []
forwarded: list[pb.FireEvent] = []
async def _on_fire(payload: dict[str, Any]) -> None:
forwarded.append(payload)
async def _on_fire(msg: pb.FireEvent) -> None:
forwarded.append(msg)
main.register("sandbox_v2/fire_event", _on_fire)
main.start()
@@ -86,8 +93,8 @@ async def test_owned_domain_event_is_forwarded(
await _wait_until(lambda: bool(forwarded))
assert len(forwarded) == 1
assert forwarded[0]["event_type"] == "zha_event"
assert forwarded[0]["event_data"]["command"] == "on"
assert forwarded[0].event_type == "zha_event"
assert struct_to_dict(forwarded[0].event_data)["command"] == "on"
await mirror.async_stop()
@@ -97,10 +104,10 @@ async def test_unapproved_event_is_dropped(
) -> None:
"""Events outside the approved-domain set don't cross the bridge."""
main, sandbox = channels
forwarded: list[dict[str, Any]] = []
forwarded: list[pb.FireEvent] = []
async def _on_fire(payload: dict[str, Any]) -> None:
forwarded.append(payload)
async def _on_fire(msg: pb.FireEvent) -> None:
forwarded.append(msg)
main.register("sandbox_v2/fire_event", _on_fire)
main.start()
@@ -128,10 +135,10 @@ async def test_internal_events_are_skipped(
) -> None:
"""``state_changed`` / ``service_registered`` are owned by other mirrors."""
main, sandbox = channels
forwarded: list[dict[str, Any]] = []
forwarded: list[pb.FireEvent] = []
async def _on_fire(payload: dict[str, Any]) -> None:
forwarded.append(payload)
async def _on_fire(msg: pb.FireEvent) -> None:
forwarded.append(msg)
main.register("sandbox_v2/fire_event", _on_fire)
main.start()
@@ -12,8 +12,11 @@ import tempfile
from types import ModuleType
from typing import Any
from hass_client._proto import sandbox_v2_pb2 as pb
from hass_client.channel import Channel
from hass_client.codec_protobuf import ProtobufCodec
from hass_client.flow_runner import FlowRunner
from hass_client.messages import dict_to_struct, listvalue_to_list, struct_to_dict
import pytest
import voluptuous as vol
@@ -42,8 +45,12 @@ def _make_channel_pair() -> tuple[Channel, Channel]:
reader_a = asyncio.StreamReader()
reader_b = asyncio.StreamReader()
return (
Channel(reader_a, _LoopbackWriter(reader_b), name="main"), # type: ignore[arg-type]
Channel(reader_b, _LoopbackWriter(reader_a), name="sandbox"), # type: ignore[arg-type]
Channel(
reader_a, _LoopbackWriter(reader_b), name="main", codec=ProtobufCodec()
), # type: ignore[arg-type]
Channel(
reader_b, _LoopbackWriter(reader_a), name="sandbox", codec=ProtobufCodec()
), # type: ignore[arg-type]
)
@@ -94,9 +101,9 @@ async def _runner_fixture() -> FlowRunner:
"homeassistant.components.phase4_demo.config_flow"
)
runner.hass.data[ha_loader.DATA_COMPONENTS]["phase4_demo"] = fake_module
runner.hass.data[ha_loader.DATA_COMPONENTS][
"phase4_demo.config_flow"
] = fake_flow_module
runner.hass.data[ha_loader.DATA_COMPONENTS]["phase4_demo.config_flow"] = (
fake_flow_module
)
runner.hass.config.components.add("phase4_demo")
try:
yield runner
@@ -118,21 +125,20 @@ async def test_flow_init_returns_form(
main.start()
sandbox.start()
result = await main.call(
"sandbox_v2/flow_init",
{"handler": "phase4_demo", "context": {"source": "user"}, "data": None},
)
init_msg = pb.FlowInit(handler="phase4_demo")
init_msg.context.update({"source": "user"})
result = await main.call("sandbox_v2/flow_init", init_msg)
assert result["type"] == "form"
assert result["step_id"] == "user"
assert result.type == "form"
assert result.step_id == "user"
# Phase 14: data_schema rides as the same list-of-fields shape
# voluptuous_serialize.convert produces, so the proxy on main can
# rebuild a usable vol.Schema (or hand the list straight to the
# frontend).
assert result["data_schema"] == [
assert listvalue_to_list(result.data_schema) == [
{"name": "host", "type": "string", "required": True}
]
assert result.get("_has_data_schema") is not True
assert result.has_data_schema is not True
async def test_flow_step_creates_entry(
@@ -144,18 +150,16 @@ async def test_flow_step_creates_entry(
main.start()
sandbox.start()
init_result = await main.call(
"sandbox_v2/flow_init",
{"handler": "phase4_demo", "context": {"source": "user"}, "data": None},
)
step_result = await main.call(
"sandbox_v2/flow_step",
{"flow_id": init_result["flow_id"], "user_input": {"host": "1.2.3.4"}},
)
init_msg = pb.FlowInit(handler="phase4_demo")
init_msg.context.update({"source": "user"})
init_result = await main.call("sandbox_v2/flow_init", init_msg)
step_msg = pb.FlowStep(flow_id=init_result.flow_id)
step_msg.user_input.CopyFrom(dict_to_struct({"host": "1.2.3.4"}))
step_result = await main.call("sandbox_v2/flow_step", step_msg)
assert step_result["type"] == "create_entry"
assert step_result["title"] == "Demo 1.2.3.4"
assert step_result["data"] == {"host": "1.2.3.4"}
assert step_result.type == "create_entry"
assert step_result.title == "Demo 1.2.3.4"
assert struct_to_dict(step_result.data) == {"host": "1.2.3.4"}
async def test_flow_step_validation_error_returns_form(
@@ -167,17 +171,15 @@ async def test_flow_step_validation_error_returns_form(
main.start()
sandbox.start()
init_result = await main.call(
"sandbox_v2/flow_init",
{"handler": "phase4_demo", "context": {"source": "user"}, "data": None},
)
step_result = await main.call(
"sandbox_v2/flow_step",
{"flow_id": init_result["flow_id"], "user_input": {"host": "bad"}},
)
init_msg = pb.FlowInit(handler="phase4_demo")
init_msg.context.update({"source": "user"})
init_result = await main.call("sandbox_v2/flow_init", init_msg)
step_msg = pb.FlowStep(flow_id=init_result.flow_id)
step_msg.user_input.CopyFrom(dict_to_struct({"host": "bad"}))
step_result = await main.call("sandbox_v2/flow_step", step_msg)
assert step_result["type"] == "form"
assert step_result["errors"] == {"host": "invalid_host"}
assert step_result.type == "form"
assert struct_to_dict(step_result.errors) == {"host": "invalid_host"}
async def test_flow_init_marshals_unique_id(
@@ -189,17 +191,12 @@ async def test_flow_init_marshals_unique_id(
main.start()
sandbox.start()
result = await main.call(
"sandbox_v2/flow_init",
{
"handler": "phase4_demo",
"context": {"source": "user", "unique_id": "demo-abc"},
"data": None,
},
)
init_msg = pb.FlowInit(handler="phase4_demo")
init_msg.context.update({"source": "user", "unique_id": "demo-abc"})
result = await main.call("sandbox_v2/flow_init", init_msg)
assert result["type"] == "form"
assert result.get("context", {}).get("unique_id") == "demo-abc"
assert result.type == "form"
assert struct_to_dict(result.context).get("unique_id") == "demo-abc"
async def test_flow_abort_is_idempotent(
@@ -212,6 +209,8 @@ async def test_flow_abort_is_idempotent(
sandbox.start()
result = await main.call(
"sandbox_v2/flow_abort", {"flow_id": "not-a-real-flow-id"}
"sandbox_v2/flow_abort", pb.FlowAbort(flow_id="not-a-real-flow-id")
)
assert result == {}
# FlowAbortResult is an empty message (was `result == {}` on the dict wire).
assert isinstance(result, pb.FlowAbortResult)
assert result.SerializeToString() == b""
@@ -20,8 +20,11 @@ from collections.abc import AsyncGenerator, Generator
import tempfile
from typing import Any
from hass_client.channel import Channel, ChannelRemoteError
from hass_client._proto import sandbox_v2_pb2 as pb
from hass_client.channel import Channel, ChannelRemoteError, JsonCodec
from hass_client.codec_protobuf import ProtobufCodec
from hass_client.flow_runner import FlowRunner
from hass_client.messages import struct_to_dict
from hass_client.sandbox_bridge import ChannelSandboxBridge
import pytest
import voluptuous as vol
@@ -283,8 +286,26 @@ def _make_channel_pair() -> tuple[Channel, Channel]:
reader_a = asyncio.StreamReader()
reader_b = asyncio.StreamReader()
return (
Channel(reader_a, _LoopbackWriter(reader_b), name="main"), # type: ignore[arg-type]
Channel(reader_b, _LoopbackWriter(reader_a), name="sandbox"), # type: ignore[arg-type]
Channel(
reader_a, _LoopbackWriter(reader_b), name="main", codec=ProtobufCodec()
), # type: ignore[arg-type]
Channel(
reader_b, _LoopbackWriter(reader_a), name="sandbox", codec=ProtobufCodec()
), # type: ignore[arg-type]
)
def _make_json_channel_pair() -> tuple[Channel, Channel]:
"""Build a JSON-codec channel pair for off-registry handlers.
Used for handlers whose payload/error types aren't in the proto
registry (e.g. ``vol.Invalid`` over an ad-hoc ``test/bad`` route).
"""
reader_a = asyncio.StreamReader()
reader_b = asyncio.StreamReader()
return (
Channel(reader_a, _LoopbackWriter(reader_b), name="main", codec=JsonCodec()), # type: ignore[arg-type]
Channel(reader_b, _LoopbackWriter(reader_a), name="sandbox", codec=JsonCodec()), # type: ignore[arg-type]
)
@@ -294,16 +315,19 @@ async def test_channel_bridge_maps_store_rpcs() -> None:
saved: dict[str, Any] = {}
removed: list[str] = []
async def _on_save(payload: dict[str, Any]) -> dict[str, bool]:
saved[payload["key"]] = payload["data"]
return {"ok": True}
async def _on_save(msg: pb.StoreSave) -> pb.StoreSaveResult:
saved[msg.key] = struct_to_dict(msg.data)
return pb.StoreSaveResult(ok=True)
async def _on_load(payload: dict[str, Any]) -> dict[str, Any] | None:
return saved.get(payload["key"])
async def _on_load(msg: pb.StoreLoad) -> pb.StoreLoadResult:
result = pb.StoreLoadResult()
if msg.key in saved:
result.data.update(saved[msg.key])
return result
async def _on_remove(payload: dict[str, Any]) -> dict[str, bool]:
removed.append(payload["key"])
return {"ok": True}
async def _on_remove(msg: pb.StoreRemove) -> pb.StoreRemoveResult:
removed.append(msg.key)
return pb.StoreRemoveResult(ok=True)
main.register("sandbox_v2/store_save", _on_save)
main.register("sandbox_v2/store_load", _on_load)
@@ -317,6 +341,9 @@ async def test_channel_bridge_maps_store_rpcs() -> None:
await bridge.async_store_save("wire", dict(wrapped))
assert saved["wire"]["data"] == {"k": "v"}
# ``async_store_load`` returns a plain dict (struct_to_dict of the
# wrapped envelope); Struct round-trips numbers as float but ``==``
# still holds against the saved dict.
assert await bridge.async_store_load("wire") == saved["wire"]
await bridge.async_store_remove("wire")
@@ -332,7 +359,9 @@ async def test_client_channel_serializes_vol_invalid() -> None:
Mirror of the main-side channel test confirms the client channel's
``error_data_for`` serialization feeds the error frame.
"""
main, sandbox = _make_channel_pair()
# ``vol.Invalid`` and the ad-hoc ``test/bad`` route aren't in the proto
# registry, so this pair rides the JSON codec rather than ProtobufCodec.
main, sandbox = _make_json_channel_pair()
async def _bad(_payload: Any) -> None:
raise vol.Invalid("expected int", path=["options", "count"])
@@ -28,9 +28,7 @@ def test_ready_msg_type_is_stable() -> None:
def test_cli_parser_requires_name_url_and_token() -> None:
"""The CLI parser accepts the manager's argv and defaults log-level."""
parser = _build_parser()
args = parser.parse_args(
["--name", "built-in", "--url", "ws://x", "--token", "t"]
)
args = parser.parse_args(["--name", "built-in", "--url", "ws://x", "--token", "t"])
assert args.name == "built-in"
assert args.url == "ws://x"
assert args.token == "t"
@@ -11,8 +11,10 @@ import asyncio
import tempfile
from typing import Any
from hass_client._proto import sandbox_v2_pb2 as pb
from hass_client.approved_domains import ApprovedDomains
from hass_client.channel import Channel
from hass_client.codec_protobuf import ProtobufCodec
from hass_client.flow_runner import FlowRunner
from hass_client.service_mirror import ServiceMirror
import pytest
@@ -41,8 +43,12 @@ def _make_channel_pair() -> tuple[Channel, Channel]:
reader_a = asyncio.StreamReader()
reader_b = asyncio.StreamReader()
return (
Channel(reader_a, _LoopbackWriter(reader_b), name="main"), # type: ignore[arg-type]
Channel(reader_b, _LoopbackWriter(reader_a), name="sandbox"), # type: ignore[arg-type]
Channel(
reader_a, _LoopbackWriter(reader_b), name="main", codec=ProtobufCodec()
), # type: ignore[arg-type]
Channel(
reader_b, _LoopbackWriter(reader_a), name="sandbox", codec=ProtobufCodec()
), # type: ignore[arg-type]
)
@@ -76,11 +82,11 @@ async def test_register_service_pushes_to_main(
) -> None:
"""An approved-domain service registration becomes one push to main."""
main, sandbox = channels
register_calls: list[dict[str, Any]] = []
register_calls: list[pb.RegisterService] = []
async def _on_register(payload: dict[str, Any]) -> dict[str, bool]:
register_calls.append(payload)
return {"ok": True, "installed": True}
async def _on_register(msg: pb.RegisterService) -> pb.RegisterServiceResult:
register_calls.append(msg)
return pb.RegisterServiceResult(ok=True, installed=True)
main.register("sandbox_v2/register_service", _on_register)
main.start()
@@ -103,9 +109,9 @@ async def test_register_service_pushes_to_main(
await _wait_until(lambda: bool(register_calls))
assert len(register_calls) == 1
assert register_calls[0]["domain"] == "phase6_demo"
assert register_calls[0]["service"] == "do_thing"
assert register_calls[0]["supports_response"] == "none"
assert register_calls[0].domain == "phase6_demo"
assert register_calls[0].service == "do_thing"
assert register_calls[0].supports_response == "none"
await mirror.async_stop()
@@ -115,11 +121,11 @@ async def test_unapproved_domain_is_rejected(
) -> None:
"""A service for an un-approved domain never reaches main."""
main, sandbox = channels
register_calls: list[dict[str, Any]] = []
register_calls: list[pb.RegisterService] = []
async def _on_register(payload: dict[str, Any]) -> dict[str, bool]:
register_calls.append(payload)
return {"ok": True, "installed": True}
async def _on_register(msg: pb.RegisterService) -> pb.RegisterServiceResult:
register_calls.append(msg)
return pb.RegisterServiceResult(ok=True, installed=True)
main.register("sandbox_v2/register_service", _on_register)
main.start()
@@ -152,16 +158,18 @@ async def test_unregister_service_propagates(
) -> None:
"""Removing a mirrored service pushes ``unregister_service`` to main."""
main, sandbox = channels
register_calls: list[dict[str, Any]] = []
unregister_calls: list[dict[str, Any]] = []
register_calls: list[pb.RegisterService] = []
unregister_calls: list[pb.UnregisterService] = []
async def _on_register(payload: dict[str, Any]) -> dict[str, bool]:
register_calls.append(payload)
return {"ok": True, "installed": True}
async def _on_register(msg: pb.RegisterService) -> pb.RegisterServiceResult:
register_calls.append(msg)
return pb.RegisterServiceResult(ok=True, installed=True)
async def _on_unregister(payload: dict[str, Any]) -> dict[str, bool]:
unregister_calls.append(payload)
return {"ok": True, "removed": True}
async def _on_unregister(
msg: pb.UnregisterService,
) -> pb.UnregisterServiceResult:
unregister_calls.append(msg)
return pb.UnregisterServiceResult(ok=True, removed=True)
main.register("sandbox_v2/register_service", _on_register)
main.register("sandbox_v2/unregister_service", _on_unregister)
@@ -182,7 +190,7 @@ async def test_unregister_service_propagates(
await _wait_until(lambda: bool(unregister_calls))
assert len(unregister_calls) == 1
assert unregister_calls[0]["domain"] == "phase6_demo"
assert unregister_calls[0]["service"] == "go"
assert unregister_calls[0].domain == "phase6_demo"
assert unregister_calls[0].service == "go"
await mirror.async_stop()
+36 -33
View File
@@ -16,7 +16,10 @@ These tests exercise:
import asyncio
from typing import Any
from hass_client._proto import sandbox_v2_pb2 as pb
from hass_client.channel import Channel
from hass_client.codec_protobuf import ProtobufCodec
from hass_client.messages import struct_to_dict
from hass_client.protocol import MSG_SHUTDOWN, MSG_STORE_LOAD, MSG_STORE_SAVE
from hass_client.sandbox import SandboxRuntime
import pytest
@@ -49,8 +52,12 @@ def _make_channel_pair() -> tuple[Channel, Channel]:
reader_a = asyncio.StreamReader()
reader_b = asyncio.StreamReader()
return (
Channel(reader_a, _LoopbackWriter(reader_b), name="main"), # type: ignore[arg-type]
Channel(reader_b, _LoopbackWriter(reader_a), name="sandbox"), # type: ignore[arg-type]
Channel(
reader_a, _LoopbackWriter(reader_b), name="main", codec=ProtobufCodec()
), # type: ignore[arg-type]
Channel(
reader_b, _LoopbackWriter(reader_a), name="sandbox", codec=ProtobufCodec()
), # type: ignore[arg-type]
)
@@ -93,7 +100,7 @@ async def _runtime_pair_fixture():
runtime.request_shutdown()
try:
await asyncio.wait_for(task, timeout=2.0)
except (TimeoutError, Exception): # noqa: BLE001
except TimeoutError, Exception: # noqa: BLE001
task.cancel()
await main_channel.close()
@@ -105,14 +112,14 @@ async def test_shutdown_handler_returns_summary_and_exits(
"""``sandbox_v2/shutdown`` replies with a summary and the runtime exits 0."""
_runtime, main_channel, task = runtime_pair
result = await asyncio.wait_for(
main_channel.call(MSG_SHUTDOWN, None), timeout=5.0
)
result = await asyncio.wait_for(main_channel.call(MSG_SHUTDOWN, None), timeout=5.0)
assert result["ok"] is True
assert result["unloaded"] == 0
# No entries → no live RestoreEntity → restore_state stays None.
assert result["restore_state"] is None
assert result.ok is True
assert result.unloaded == 0
# No entries → no live RestoreEntity → restore_state stays unset.
# Proto-forced: old ``result["restore_state"] is None`` becomes a
# presence check on the optional field.
assert not result.HasField("restore_state")
# The runtime sets its shutdown event right after replying — wait for
# ``run()`` to return on its own; no SIGTERM should be needed.
@@ -143,13 +150,11 @@ async def test_shutdown_returns_restore_state_payload(
last_seen=dt_util.utcnow(),
)
reply = await asyncio.wait_for(
main_channel.call(MSG_SHUTDOWN, None), timeout=5.0
)
reply = await asyncio.wait_for(main_channel.call(MSG_SHUTDOWN, None), timeout=5.0)
assert reply["ok"] is True
restore_payload = reply["restore_state"]
assert isinstance(restore_payload, dict)
assert reply.ok is True
assert reply.HasField("restore_state")
restore_payload = struct_to_dict(reply.restore_state)
assert restore_payload["version"] == restore_state.STORAGE_VERSION
assert restore_payload["key"] == restore_state.STORAGE_KEY
entity_ids = [item["state"]["entity_id"] for item in restore_payload["data"]]
@@ -177,10 +182,8 @@ async def test_shutdown_fires_final_write_event(
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_FINAL_WRITE, _on_final_write)
reply = await asyncio.wait_for(
main_channel.call(MSG_SHUTDOWN, None), timeout=5.0
)
assert reply["ok"] is True
reply = await asyncio.wait_for(main_channel.call(MSG_SHUTDOWN, None), timeout=5.0)
assert reply.ok is True
assert len(fired) == 1
capsys.readouterr()
@@ -200,10 +203,11 @@ async def test_shutdown_flushes_pending_delay_save(
runtime, main_channel, _task = runtime_pair
hass = runtime._flow_runner.hass # noqa: SLF001
saves: list[dict[str, Any]] = []
saves: list[pb.StoreSave] = []
async def _on_store_save(payload: dict[str, Any]) -> None:
saves.append(payload)
async def _on_store_save(msg: pb.StoreSave) -> pb.StoreSaveResult:
saves.append(msg)
return pb.StoreSaveResult(ok=True)
main_channel.register(MSG_STORE_SAVE, _on_store_save)
@@ -214,15 +218,13 @@ async def test_shutdown_flushes_pending_delay_save(
store = _storage.Store(hass, 1, "phase12_test")
store.async_delay_save(lambda: {"pending": True}, 3600)
reply = await asyncio.wait_for(
main_channel.call(MSG_SHUTDOWN, None), timeout=5.0
)
assert reply["ok"] is True
reply = await asyncio.wait_for(main_channel.call(MSG_SHUTDOWN, None), timeout=5.0)
assert reply.ok is True
save_keys = [save["key"] for save in saves]
save_keys = [save.key for save in saves]
assert "phase12_test" in save_keys
saved = next(save for save in saves if save["key"] == "phase12_test")
assert saved["data"]["data"] == {"pending": True}
saved = next(save for save in saves if save.key == "phase12_test")
assert struct_to_dict(saved.data)["data"] == {"pending": True}
capsys.readouterr()
@@ -233,9 +235,10 @@ async def test_run_warm_loads_restore_state_on_startup(
main_channel, sandbox_channel = _make_channel_pair()
load_calls: list[str] = []
async def _on_load(payload: dict[str, Any]) -> dict[str, Any] | None:
load_calls.append(payload["key"])
return None
async def _on_load(msg: pb.StoreLoad) -> pb.StoreLoadResult:
load_calls.append(msg.key)
# Empty result = cache miss (old ``return None``).
return pb.StoreLoadResult()
main_channel.register(MSG_STORE_LOAD, _on_load)
main_channel.start()
@@ -10,6 +10,8 @@ shape and the channel pair's call/response round-trip.
import asyncio
from hass_client._proto import sandbox_v2_pb2 as pb
from hass_client.messages import struct_to_dict
from hass_client.testing._inproc import _LoopbackWriter, make_inproc_channel_pair
@@ -40,20 +42,28 @@ async def test_loopback_writer_drain_is_noop() -> None:
async def test_make_inproc_channel_pair_round_trips_a_call() -> None:
"""The two channels can call each other end-to-end without a process boundary."""
"""The two channels can call each other end-to-end without a process boundary.
Uses a real registered message type (``store_load``) since the pair now
speaks protobuf the handler echoes the request key back so the
round-trip is verified end-to-end.
"""
mgr_channel, rt_channel = make_inproc_channel_pair(group="built-in")
async def handler(payload: object) -> dict[str, str]:
return {"echo": "ok", "saw": str(payload)}
async def handler(payload: pb.StoreLoad) -> pb.StoreLoadResult:
result = pb.StoreLoadResult()
result.data.update({"echo": "ok", "saw": payload.key})
return result
rt_channel.register("test/echo", handler)
rt_channel.register("sandbox_v2/store_load", handler)
rt_channel.start()
mgr_channel.start()
try:
result = await asyncio.wait_for(
mgr_channel.call("test/echo", "hi"), timeout=2.0
mgr_channel.call("sandbox_v2/store_load", pb.StoreLoad(key="hi")),
timeout=2.0,
)
assert result == {"echo": "ok", "saw": "hi"}
assert struct_to_dict(result.data) == {"echo": "ok", "saw": "hi"}
finally:
await mgr_channel.close()
await rt_channel.close()
+38
View File
@@ -0,0 +1,38 @@
#!/usr/bin/env bash
# Drift guard for the checked-in sandbox_v2 protobuf gencode.
#
# Regenerates the _pb2 mirrors from sandbox_v2.proto (via generate.sh, which
# bootstraps its own throwaway venv — grpcio-tools is deliberately NOT a
# project dependency, since it would bump protobuf past the pinned 6.32.0)
# and fails if the result differs from what is checked in.
#
# Degrades gracefully: if `uv` (needed to build the isolated protoc venv) is
# not on PATH, it skips with a notice and exits 0 rather than blocking the
# commit. This is why it is wired as a MANUAL-stage prek hook
# (`prek run --hook-stage manual sandbox-v2-proto-drift`) / dedicated CI lane,
# not an every-commit hook.
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"
cd "${REPO_ROOT}"
if ! command -v uv >/dev/null 2>&1; then
echo "sandbox_v2 proto drift guard: 'uv' not found — skipping (install uv to run)."
exit 0
fi
HA_DEST="homeassistant/components/sandbox_v2/_proto"
CLIENT_DEST="sandbox_v2/hass_client/hass_client/_proto"
bash "${SCRIPT_DIR}/generate.sh" >/dev/null
if ! git diff --exit-code -- "${HA_DEST}" "${CLIENT_DEST}"; then
echo
echo "ERROR: checked-in protobuf gencode is out of date with sandbox_v2.proto."
echo "Run 'sandbox_v2/proto/generate.sh' and commit the regenerated _pb2 files."
exit 1
fi
echo "sandbox_v2 proto drift guard: gencode matches sandbox_v2.proto."
+49
View File
@@ -0,0 +1,49 @@
#!/usr/bin/env bash
# Regenerate the checked-in protobuf gencode for both mirrors.
#
# Core has no build-time protoc and grpcio-tools is NOT a project dependency
# (installing it into the main venv would bump protobuf past the pinned
# 6.32.0). So this script bootstraps a throwaway, isolated venv pinned to the
# runtime's protobuf and generates into both no-cross-import mirrors:
#
# homeassistant/components/sandbox_v2/_proto/sandbox_v2_pb2.py(+.pyi)
# sandbox_v2/hass_client/hass_client/_proto/sandbox_v2_pb2.py(+.pyi)
#
# Usage (from the repo root): sandbox_v2/proto/generate.sh
#
# After running, `git diff --exit-code` the two _pb2 paths must be clean — a
# dirty diff means the checked-in gencode drifted from the .proto.
set -euo pipefail
# Resolve the repo root from this script's location so it works from anywhere.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"
cd "${REPO_ROOT}"
PROTO_DIR="sandbox_v2/proto"
HA_DEST="homeassistant/components/sandbox_v2/_proto"
CLIENT_DEST="sandbox_v2/hass_client/hass_client/_proto"
# pinned to match homeassistant/package_constraints.txt; grpcio-tools==1.80.0
# (resolved by uv) emits gencode requiring protobuf >= 6.31.1, satisfied here.
PROTOBUF_PIN="protobuf==6.32.0"
VENV_DIR="$(mktemp -d -t sandbox_v2_protogen_XXXXXX)"
trap 'rm -rf "${VENV_DIR}"' EXIT
echo "Bootstrapping isolated protogen venv at ${VENV_DIR} ..."
uv venv "${VENV_DIR}" --python 3.14 >/dev/null
uv pip install --python "${VENV_DIR}" "${PROTOBUF_PIN}" grpcio-tools mypy-protobuf >/dev/null
for DEST in "${HA_DEST}" "${CLIENT_DEST}"; do
mkdir -p "${DEST}"
touch "${DEST}/__init__.py"
echo "Generating into ${DEST} ..."
"${VENV_DIR}/bin/python" -m grpc_tools.protoc \
-I "${PROTO_DIR}" \
--python_out="${DEST}" \
--pyi_out="${DEST}" \
"${PROTO_DIR}/sandbox_v2.proto"
done
echo "Done. Verify with: git diff --exit-code ${HA_DEST} ${CLIENT_DEST}"
+322
View File
@@ -0,0 +1,322 @@
// Sandbox v2 control-channel wire protocol.
//
// Single source of truth for the protobuf messages exchanged between the HA
// Core integration (`homeassistant/components/sandbox_v2`) and the sandbox
// runtime (`hass_client`). The generated `_pb2.py` + `_pb2.pyi` are checked
// into BOTH mirrors (the two sides never cross-import); regenerate with
// `sandbox_v2/proto/generate.sh` after editing this file.
//
// Field conventions:
// * Genuinely dynamic payloads (service_data, target, state attributes,
// capabilities, flow `data`/`errors`/`context`, the wrapped Store
// envelope) cross as `google.protobuf.Struct`; the serialized
// voluptuous schema (a list of field dicts) crosses as
// `google.protobuf.ListValue`. Everything else is an explicit field.
// * Context security: only `context_id` ever crosses from the sandbox.
// `parent_id` / `user_id` are NEVER on the wire main resolves the id
// to its own authoritative Context at dispatch time.
syntax = "proto3";
package sandbox_v2;
import "google/protobuf/struct.proto";
// --- Envelope -------------------------------------------------------------
// One wire message. `type` keeps the existing string-keyed dispatch and is
// set on responses too, so a stateless codec can look up the result class
// from `type` on both encode and decode.
message Frame {
uint32 id = 1; // 0 = push (no reply); >0 = call / response
string type = 2; // e.g. "sandbox_v2/register_entity"
oneof body {
bytes request = 3; // serialized request message (call or push)
Response response = 4; // response to a call we received
}
}
message Response {
bool ok = 1;
bytes result = 2; // serialized result message (set when ok)
Error error = 3; // set when ok = false
}
// Carries fidelity #7's structured voluptuous data natively so the peer can
// rebuild the original vol.Invalid / vol.MultipleInvalid with its path.
message Error {
string message = 1;
string type = 2; // exception class name
repeated InvalidError invalid = 3; // voluptuous error path(s)
bool multiple = 4; // true = MultipleInvalid, false = Invalid
}
message InvalidError {
string message = 1;
repeated string path = 2;
}
// --- Shared sub-messages --------------------------------------------------
// A two-element pair. Used for DeviceInfo identifiers / connections (a set
// of (domain, id) pairs on HA's side) and via_device (one pair).
message DevicePair {
string key = 1;
string value = 2;
}
// Mirror of HA's DeviceInfo TypedDict. Set/tuple-shaped fields become
// repeated DevicePair; entry_type is the enum's string value. Only the keys
// the entity bridge actually forwards are modelled; unset scalar strings
// (default "") are treated as absent on the main side.
message DeviceInfo {
repeated DevicePair identifiers = 1;
repeated DevicePair connections = 2;
optional DevicePair via_device = 3;
string entry_type = 4;
string name = 5;
string manufacturer = 6;
string model = 7;
string model_id = 8;
string sw_version = 9;
string hw_version = 10;
string serial_number = 11;
string suggested_area = 12;
string configuration_url = 13;
string default_name = 14;
string default_manufacturer = 15;
string default_model = 16;
string translation_key = 17;
}
// --- entry_setup / entry_unload (main -> sandbox) -------------------------
message EntrySetup {
string entry_id = 1;
string domain = 2;
string title = 3;
google.protobuf.Struct data = 4;
google.protobuf.Struct options = 5;
string source = 6;
optional string unique_id = 7;
int32 version = 8;
int32 minor_version = 9;
}
message EntrySetupResult {
bool ok = 1;
optional string reason = 2;
}
message EntryUnload {
string entry_id = 1;
}
message EntryUnloadResult {
bool ok = 1;
}
// --- call_service (main -> sandbox) ---------------------------------------
message CallService {
string domain = 1;
string service = 2;
google.protobuf.Struct target = 3; // dynamic
google.protobuf.Struct service_data = 4; // dynamic
optional string context_id = 5; // wire-safe: only the id
bool return_response = 6;
}
// Typed envelope so every call-service response goes through the same shape;
// the dynamic payload sits inside `data`.
message ServiceResponse {
google.protobuf.Struct data = 1;
}
message CallServiceResult {
optional ServiceResponse response = 1; // unset when no response returned
}
// --- shutdown / ping (main -> sandbox) ------------------------------------
message Shutdown {}
message ShutdownResult {
bool ok = 1;
int32 unloaded = 2;
optional google.protobuf.Struct restore_state = 3; // wrapped Store envelope
}
message Ping {}
message PingResult {
string pong = 1;
}
// --- ready handshake (sandbox -> main, push) ------------------------------
message Ready {}
// --- config flow (main -> sandbox) ----------------------------------------
message FlowInit {
string handler = 1;
google.protobuf.Struct context = 2; // dynamic
google.protobuf.Struct data = 3; // dynamic (initial flow data)
}
message FlowStep {
string flow_id = 1;
google.protobuf.Struct user_input = 2; // dynamic
}
message FlowAbort {
string flow_id = 1;
}
message FlowAbortResult {}
// Marshalled FlowResult. Scalar fields are explicit; the dynamic dicts
// (data / errors / context / description_placeholders) are Struct; the
// serialized voluptuous schema is a ListValue of field dicts.
message FlowResult {
string type = 1;
optional string flow_id = 2;
optional string handler = 3;
optional string step_id = 4;
optional string reason = 5;
optional string title = 6;
optional string description = 7;
optional bool last_step = 8;
optional string preview = 9;
optional int32 version = 10;
optional int32 minor_version = 11;
google.protobuf.Struct data = 12;
google.protobuf.Struct options = 13;
google.protobuf.Struct errors = 14;
google.protobuf.Struct description_placeholders = 15;
google.protobuf.Struct context = 16;
google.protobuf.ListValue data_schema = 17;
// True when a data_schema existed on the sandbox flow but could not be
// serialized main then renders a schema-less form rather than abort.
bool has_data_schema = 18;
}
// --- register_entity (sandbox -> main) ------------------------------------
// Identity: mirrors HA's homeassistant.helpers.entity.EntityDescription
// dataclass + the entity's DeviceInfo.
message EntityInfo {
message Description {
optional string name = 1;
optional string icon = 2;
optional string entity_category = 3;
optional string device_class = 4;
int32 supported_features = 5;
optional string translation_key = 6;
}
optional Description description = 1;
optional DeviceInfo device_info = 2;
}
// Runtime starting state what HA needs to surface the entity for the
// first time.
message InitialState {
optional string state = 1;
google.protobuf.Struct capabilities = 2; // dynamic
google.protobuf.Struct attributes = 3; // dynamic
}
// Outer wire message for register_entity. Sub-messages group by HA's own
// organization: EntityInfo = identity; InitialState = runtime start.
message EntityDescription {
string entry_id = 1;
string domain = 2;
string sandbox_entity_id = 3;
optional string unique_id = 4;
bool has_entity_name = 5;
EntityInfo info = 6;
InitialState initial = 7;
}
message RegisterEntityResult {
string entity_id = 1;
}
message UnregisterEntity {
string sandbox_entity_id = 1;
}
message UnregisterEntityResult {
bool ok = 1;
}
// --- state_changed (sandbox -> main, push) --------------------------------
// Flattened the old nested `new_state` wrapper is gone. Context never
// crosses with parent_id / user_id; the id alone is wire-safe.
message StateChanged {
string sandbox_entity_id = 1;
optional string state = 2;
google.protobuf.Struct attributes = 3; // dynamic
optional string context_id = 4;
}
// --- register_service / unregister_service (sandbox -> main) --------------
message RegisterService {
string domain = 1;
string service = 2;
string supports_response = 3; // "none" / "optional" / "only"
google.protobuf.ListValue schema = 4; // serialized voluptuous, empty = none
}
message RegisterServiceResult {
bool ok = 1;
bool installed = 2;
}
message UnregisterService {
string domain = 1;
string service = 2;
}
message UnregisterServiceResult {
bool ok = 1;
bool removed = 2;
}
// --- fire_event (sandbox -> main, push) -----------------------------------
message FireEvent {
string event_type = 1;
google.protobuf.Struct event_data = 2; // dynamic
optional string context_id = 3;
}
// --- store_load / store_save / store_remove (sandbox -> main) -------------
message StoreLoad {
string key = 1;
}
message StoreLoadResult {
optional google.protobuf.Struct data = 1; // wrapped envelope, unset = miss
}
message StoreSave {
string key = 1;
google.protobuf.Struct data = 2; // wrapped envelope
}
message StoreSaveResult {
bool ok = 1;
}
message StoreRemove {
string key = 1;
}
message StoreRemoveResult {
bool ok = 1;
}
+12 -3
View File
@@ -13,7 +13,8 @@ Provides:
import asyncio
from homeassistant.components.sandbox_v2.channel import Channel
from homeassistant.components.sandbox_v2.channel import Channel, JsonCodec
from homeassistant.components.sandbox_v2.codec_protobuf import ProtobufCodec
class _LoopbackWriter:
@@ -47,12 +48,18 @@ def make_channel_pair(
name_b: str = "b",
max_inflight_a: int | None = None,
max_inflight_b: int | None = None,
use_json: bool = False,
) -> tuple[Channel, Channel]:
"""Return two channels connected to each other in-memory.
``max_inflight_a`` / ``max_inflight_b`` override the per-side
handler concurrency cap when set; otherwise the channel's default
applies. Useful for exercising the bounded-semaphore path.
The pair speaks protobuf by default (production parity, so real
handlers receive typed messages). ``use_json=True`` falls back to the
registry-free :class:`JsonCodec` for channel-core tests that drive
synthetic message types with plain dict payloads.
"""
reader_a = asyncio.StreamReader()
reader_b = asyncio.StreamReader()
@@ -64,8 +71,10 @@ def make_channel_pair(
kwargs_b: dict[str, int] = (
{"max_inflight": max_inflight_b} if max_inflight_b is not None else {}
)
channel_a = Channel(reader_a, writer_a, name=name_a, **kwargs_a) # type: ignore[arg-type]
channel_b = Channel(reader_b, writer_b, name=name_b, **kwargs_b) # type: ignore[arg-type]
codec_a = JsonCodec() if use_json else ProtobufCodec()
codec_b = JsonCodec() if use_json else ProtobufCodec()
channel_a = Channel(reader_a, writer_a, name=name_a, codec=codec_a, **kwargs_a) # type: ignore[arg-type]
channel_b = Channel(reader_b, writer_b, name=name_b, codec=codec_b, **kwargs_b) # type: ignore[arg-type]
return channel_a, channel_b
+160 -166
View File
@@ -6,12 +6,17 @@ from typing import Any
import pytest
import voluptuous as vol
from homeassistant.components.sandbox_v2._proto import sandbox_v2_pb2 as pb
from homeassistant.components.sandbox_v2.bridge import (
SandboxBridge,
SandboxEntityDescription,
_translate_remote_error,
)
from homeassistant.components.sandbox_v2.channel import Channel, ChannelRemoteError
from homeassistant.components.sandbox_v2.messages import (
make_entity_description,
struct_to_dict,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_ON
from homeassistant.core import HomeAssistant, callback
@@ -56,19 +61,19 @@ async def test_register_entity_creates_proxy_and_returns_entity_id(
"""A ``register_entity`` push creates a live proxy on the right domain."""
bridge, main_channel, sandbox_channel = await _wire(hass)
payload = {
"entry_id": entry.entry_id,
"domain": "light",
"sandbox_entity_id": "light.kitchen",
"unique_id": "sandbox-kitchen",
"name": "Kitchen",
"supported_features": 0,
"capabilities": {"supported_color_modes": ["onoff"]},
"initial_state": STATE_ON,
payload = make_entity_description(
entry_id=entry.entry_id,
domain="light",
sandbox_entity_id="light.kitchen",
unique_id="sandbox-kitchen",
name="Kitchen",
supported_features=0,
capabilities={"supported_color_modes": ["onoff"]},
initial_state=STATE_ON,
# Light requires color_mode when ON, so feed it through the
# initial cache to keep state_attributes from raising.
"initial_attributes": {"color_mode": "onoff"},
}
initial_attributes={"color_mode": "onoff"},
)
try:
result = await sandbox_channel.call("sandbox_v2/register_entity", payload)
@@ -76,8 +81,8 @@ async def test_register_entity_creates_proxy_and_returns_entity_id(
await main_channel.close()
await sandbox_channel.close()
assert result["entity_id"].startswith("light.")
state = hass.states.get(result["entity_id"])
assert result.entity_id.startswith("light.")
state = hass.states.get(result.entity_id)
assert state is not None
assert state.state == STATE_ON
# The bridge tracks the proxy by its sandbox-side entity_id.
@@ -101,17 +106,17 @@ async def test_register_entity_prefixes_unique_id_with_source_domain(
entry_b = MockConfigEntry(domain="demo_b", title="B")
entry_b.add_to_hass(hass)
def _payload(entry_id: str, sandbox_entity_id: str) -> dict[str, Any]:
return {
"entry_id": entry_id,
"domain": "light",
"sandbox_entity_id": sandbox_entity_id,
"unique_id": "1",
"supported_features": 0,
"capabilities": {"supported_color_modes": ["onoff"]},
"initial_state": STATE_ON,
"initial_attributes": {"color_mode": "onoff"},
}
def _payload(entry_id: str, sandbox_entity_id: str) -> pb.EntityDescription:
return make_entity_description(
entry_id=entry_id,
domain="light",
sandbox_entity_id=sandbox_entity_id,
unique_id="1",
supported_features=0,
capabilities={"supported_color_modes": ["onoff"]},
initial_state=STATE_ON,
initial_attributes={"color_mode": "onoff"},
)
try:
result_a = await sandbox_channel.call(
@@ -125,13 +130,13 @@ async def test_register_entity_prefixes_unique_id_with_source_domain(
await sandbox_channel.close()
# Both proxies landed as distinct entities.
assert result_a["entity_id"] != result_b["entity_id"]
assert hass.states.get(result_a["entity_id"]) is not None
assert hass.states.get(result_b["entity_id"]) is not None
assert result_a.entity_id != result_b.entity_id
assert hass.states.get(result_a.entity_id) is not None
assert hass.states.get(result_b.entity_id) is not None
# Registry rows carry the domain-prefixed unique_ids, not a bare "1".
assert entity_registry.async_get(result_a["entity_id"]).unique_id == "demo_a:1"
assert entity_registry.async_get(result_b["entity_id"]).unique_id == "demo_b:1"
assert entity_registry.async_get(result_a.entity_id).unique_id == "demo_a:1"
assert entity_registry.async_get(result_b.entity_id).unique_id == "demo_b:1"
async def test_register_entity_upsert_updates_name_in_place(
@@ -140,18 +145,18 @@ async def test_register_entity_upsert_updates_name_in_place(
"""A re-sent registration updates the proxy without adding a duplicate."""
bridge, main_channel, sandbox_channel = await _wire(hass)
def _payload(name: str) -> dict[str, Any]:
return {
"entry_id": entry.entry_id,
"domain": "light",
"sandbox_entity_id": "light.lamp",
"unique_id": "lamp",
"name": name,
"supported_features": 0,
"capabilities": {"supported_color_modes": ["onoff"]},
"initial_state": STATE_ON,
"initial_attributes": {"color_mode": "onoff"},
}
def _payload(name: str) -> pb.EntityDescription:
return make_entity_description(
entry_id=entry.entry_id,
domain="light",
sandbox_entity_id="light.lamp",
unique_id="lamp",
name=name,
supported_features=0,
capabilities={"supported_color_modes": ["onoff"]},
initial_state=STATE_ON,
initial_attributes={"color_mode": "onoff"},
)
try:
first = await sandbox_channel.call(
@@ -165,11 +170,11 @@ async def test_register_entity_upsert_updates_name_in_place(
await sandbox_channel.close()
# Same entity_id back, single tracked proxy — no duplicate created.
assert first["entity_id"] == second["entity_id"]
assert first.entity_id == second.entity_id
assert len(bridge._entities) == 1
proxy = bridge._entities["light.lamp"]
assert proxy._attr_name == "New Name"
state = hass.states.get(second["entity_id"])
state = hass.states.get(second.entity_id)
assert state is not None
assert state.attributes["friendly_name"] == "New Name"
@@ -182,22 +187,22 @@ async def test_register_entity_upsert_refreshes_device(
"""A re-sent registration with new device_info updates the device entry."""
bridge, main_channel, sandbox_channel = await _wire(hass)
def _payload(sw_version: str) -> dict[str, Any]:
return {
"entry_id": entry.entry_id,
"domain": "light",
"sandbox_entity_id": "light.lamp",
"unique_id": "lamp",
"supported_features": 0,
"capabilities": {"supported_color_modes": ["onoff"]},
"initial_state": STATE_ON,
"initial_attributes": {"color_mode": "onoff"},
"device_info": {
def _payload(sw_version: str) -> pb.EntityDescription:
return make_entity_description(
entry_id=entry.entry_id,
domain="light",
sandbox_entity_id="light.lamp",
unique_id="lamp",
supported_features=0,
capabilities={"supported_color_modes": ["onoff"]},
initial_state=STATE_ON,
initial_attributes={"color_mode": "onoff"},
device_info={
"identifiers": [["demo", "dev-1"]],
"name": "Lamp Device",
"sw_version": sw_version,
},
}
)
try:
await sandbox_channel.call("sandbox_v2/register_entity", _payload("1.0"))
@@ -223,37 +228,30 @@ async def test_state_changed_push_updates_proxy(
"""A subsequent ``state_changed`` push updates the proxy's cache."""
bridge, main_channel, sandbox_channel = await _wire(hass)
register = {
"entry_id": entry.entry_id,
"domain": "light",
"sandbox_entity_id": "light.lamp",
"unique_id": "sandbox-lamp",
"supported_features": 0,
register = make_entity_description(
entry_id=entry.entry_id,
domain="light",
sandbox_entity_id="light.lamp",
unique_id="sandbox-lamp",
supported_features=0,
# Brightness color mode so the light surfaces ``brightness`` as
# a first-class attribute when on.
"capabilities": {"supported_color_modes": ["brightness"]},
"initial_state": "off",
"initial_attributes": {},
}
capabilities={"supported_color_modes": ["brightness"]},
initial_state="off",
initial_attributes={},
)
try:
result = await sandbox_channel.call("sandbox_v2/register_entity", register)
await sandbox_channel.push(
"sandbox_v2/state_changed",
{
"sandbox_entity_id": "light.lamp",
"new_state": {
"state": STATE_ON,
"attributes": {"brightness": 250, "color_mode": "brightness"},
},
},
)
state_changed = pb.StateChanged(sandbox_entity_id="light.lamp", state=STATE_ON)
state_changed.attributes.update({"brightness": 250, "color_mode": "brightness"})
await sandbox_channel.push("sandbox_v2/state_changed", state_changed)
# Give the push handler a tick to land.
for _ in range(20):
state = hass.states.get(result["entity_id"])
state = hass.states.get(result.entity_id)
if state is not None and state.state == STATE_ON:
break
await asyncio.sleep(0)
state = hass.states.get(result["entity_id"])
state = hass.states.get(result.entity_id)
finally:
await main_channel.close()
await sandbox_channel.close()
@@ -271,30 +269,30 @@ async def test_proxy_method_translates_to_call_service(
) -> None:
"""Calling ``light.turn_on`` on a proxy fires one ``call_service`` RPC."""
_bridge, main_channel, sandbox_channel = await _wire(hass)
calls: list[dict[str, Any]] = []
calls: list[pb.CallService] = []
async def _on_call_service(payload: dict[str, Any]) -> Any:
async def _on_call_service(payload: pb.CallService) -> Any:
calls.append(payload)
return None
sandbox_channel.register("sandbox_v2/call_service", _on_call_service)
register = {
"entry_id": entry.entry_id,
"domain": "light",
"sandbox_entity_id": "light.bedroom",
"unique_id": "sandbox-bedroom",
"supported_features": 0,
"capabilities": {"supported_color_modes": ["onoff"]},
"initial_state": "off",
"initial_attributes": {},
}
register = make_entity_description(
entry_id=entry.entry_id,
domain="light",
sandbox_entity_id="light.bedroom",
unique_id="sandbox-bedroom",
supported_features=0,
capabilities={"supported_color_modes": ["onoff"]},
initial_state="off",
initial_attributes={},
)
try:
result = await sandbox_channel.call("sandbox_v2/register_entity", register)
await hass.services.async_call(
"light",
"turn_on",
{"entity_id": result["entity_id"]},
{"entity_id": result.entity_id},
blocking=True,
)
finally:
@@ -302,9 +300,9 @@ async def test_proxy_method_translates_to_call_service(
await sandbox_channel.close()
assert len(calls) == 1
assert calls[0]["domain"] == "light"
assert calls[0]["service"] == "turn_on"
assert calls[0]["target"] == {"entity_id": ["light.bedroom"]}
assert calls[0].domain == "light"
assert calls[0].service == "turn_on"
assert struct_to_dict(calls[0].target) == {"entity_id": ["light.bedroom"]}
async def test_proxy_method_batches_concurrent_calls(
@@ -312,9 +310,9 @@ async def test_proxy_method_batches_concurrent_calls(
) -> None:
"""Many entities targeted in one tick coalesce into one ``call_service``."""
bridge, main_channel, sandbox_channel = await _wire(hass)
calls: list[dict[str, Any]] = []
calls: list[pb.CallService] = []
async def _on_call_service(payload: dict[str, Any]) -> Any:
async def _on_call_service(payload: pb.CallService) -> Any:
calls.append(payload)
return None
@@ -323,16 +321,16 @@ async def test_proxy_method_batches_concurrent_calls(
sandbox_ids = []
try:
for idx in range(5):
register = {
"entry_id": entry.entry_id,
"domain": "light",
"sandbox_entity_id": f"light.bulb_{idx}",
"unique_id": f"sandbox-bulb-{idx}",
"supported_features": 0,
"capabilities": {"supported_color_modes": ["onoff"]},
"initial_state": "off",
"initial_attributes": {},
}
register = make_entity_description(
entry_id=entry.entry_id,
domain="light",
sandbox_entity_id=f"light.bulb_{idx}",
unique_id=f"sandbox-bulb-{idx}",
supported_features=0,
capabilities={"supported_color_modes": ["onoff"]},
initial_state="off",
initial_attributes={},
)
await sandbox_channel.call("sandbox_v2/register_entity", register)
sandbox_ids.append(f"light.bulb_{idx}")
@@ -348,9 +346,9 @@ async def test_proxy_method_batches_concurrent_calls(
await sandbox_channel.close()
assert len(calls) == 1
assert calls[0]["domain"] == "light"
assert calls[0]["service"] == "turn_on"
assert sorted(calls[0]["target"]["entity_id"]) == sorted(sandbox_ids)
assert calls[0].domain == "light"
assert calls[0].service == "turn_on"
assert sorted(struct_to_dict(calls[0].target)["entity_id"]) == sorted(sandbox_ids)
async def test_proxy_method_exception_translated(
@@ -371,16 +369,16 @@ async def test_proxy_method_exception_translated(
main_channel.call = _fake_call # type: ignore[method-assign]
register = {
"entry_id": entry.entry_id,
"domain": "light",
"sandbox_entity_id": "light.error",
"unique_id": "sandbox-error",
"supported_features": 0,
"capabilities": {"supported_color_modes": ["onoff"]},
"initial_state": "off",
"initial_attributes": {},
}
register = make_entity_description(
entry_id=entry.entry_id,
domain="light",
sandbox_entity_id="light.error",
unique_id="sandbox-error",
supported_features=0,
capabilities={"supported_color_modes": ["onoff"]},
initial_state="off",
initial_attributes={},
)
try:
# Register goes through the call path too, so register before we
# patch out the channel.
@@ -390,7 +388,7 @@ async def test_proxy_method_exception_translated(
# We need to test the bridge's direct call path; build the proxy by
# hand instead of going through register_entity.
description = SandboxEntityDescription.from_payload(register)
description = SandboxEntityDescription.from_proto(register)
proxy_cls = bridge._build_proxy(description).__class__
proxy = proxy_cls(bridge, description)
@@ -495,11 +493,11 @@ async def test_register_entity_for_unknown_entry_raises(
with pytest.raises(ChannelRemoteError):
await sandbox_channel.call(
"sandbox_v2/register_entity",
{
"entry_id": "no-such-entry",
"domain": "light",
"sandbox_entity_id": "light.ghost",
},
make_entity_description(
entry_id="no-such-entry",
domain="light",
sandbox_entity_id="light.ghost",
),
)
finally:
await main_channel.close()
@@ -521,31 +519,31 @@ async def test_register_entity_auto_loads_domain_component(
try:
result = await sandbox_channel.call(
"sandbox_v2/register_entity",
{
"entry_id": entry.entry_id,
"domain": "switch",
"sandbox_entity_id": "switch.outlet",
"unique_id": "sandbox-outlet",
"supported_features": 0,
"capabilities": {},
"initial_state": "off",
"initial_attributes": {},
},
make_entity_description(
entry_id=entry.entry_id,
domain="switch",
sandbox_entity_id="switch.outlet",
unique_id="sandbox-outlet",
supported_features=0,
capabilities={},
initial_state="off",
initial_attributes={},
),
)
finally:
await main_channel.close()
await sandbox_channel.close()
assert result["entity_id"].startswith("switch.")
assert result.entity_id.startswith("switch.")
assert "switch" in hass.config.components
async def test_register_service_installs_forwarder(hass: HomeAssistant) -> None:
"""A sandbox-registered service appears on main and forwards calls back."""
_bridge, main_channel, sandbox_channel = await _wire(hass)
seen_calls: list[dict[str, Any]] = []
seen_calls: list[pb.CallService] = []
async def _on_call_service(payload: dict[str, Any]) -> Any:
async def _on_call_service(payload: pb.CallService) -> Any:
seen_calls.append(payload)
return None
@@ -554,13 +552,13 @@ async def test_register_service_installs_forwarder(hass: HomeAssistant) -> None:
try:
result = await sandbox_channel.call(
"sandbox_v2/register_service",
{
"domain": "phase6_demo",
"service": "do_thing",
"supports_response": "none",
},
pb.RegisterService(
domain="phase6_demo",
service="do_thing",
supports_response="none",
),
)
assert result["installed"] is True
assert result.installed is True
assert hass.services.has_service("phase6_demo", "do_thing")
await hass.services.async_call(
@@ -571,9 +569,9 @@ async def test_register_service_installs_forwarder(hass: HomeAssistant) -> None:
await sandbox_channel.close()
assert len(seen_calls) == 1
assert seen_calls[0]["domain"] == "phase6_demo"
assert seen_calls[0]["service"] == "do_thing"
assert seen_calls[0]["service_data"] == {"foo": "bar"}
assert seen_calls[0].domain == "phase6_demo"
assert seen_calls[0].service == "do_thing"
assert struct_to_dict(seen_calls[0].service_data) == {"foo": "bar"}
async def test_register_service_skips_existing_handler(
@@ -590,17 +588,17 @@ async def test_register_service_skips_existing_handler(
try:
result = await sandbox_channel.call(
"sandbox_v2/register_service",
{
"domain": "phase6_local",
"service": "noop",
"supports_response": "none",
},
pb.RegisterService(
domain="phase6_local",
service="noop",
supports_response="none",
),
)
finally:
await main_channel.close()
await sandbox_channel.close()
assert result["installed"] is False
assert result.installed is False
# The existing handler is still in place — the bridge didn't replace it.
assert hass.services.has_service("phase6_local", "noop")
@@ -614,23 +612,23 @@ async def test_unregister_service_removes_forwarder(
try:
await sandbox_channel.call(
"sandbox_v2/register_service",
{
"domain": "phase6_demo",
"service": "stop",
"supports_response": "none",
},
pb.RegisterService(
domain="phase6_demo",
service="stop",
supports_response="none",
),
)
assert hass.services.has_service("phase6_demo", "stop")
result = await sandbox_channel.call(
"sandbox_v2/unregister_service",
{"domain": "phase6_demo", "service": "stop"},
pb.UnregisterService(domain="phase6_demo", service="stop"),
)
finally:
await main_channel.close()
await sandbox_channel.close()
assert result["removed"] is True
assert result.removed is True
assert not hass.services.has_service("phase6_demo", "stop")
@@ -647,13 +645,9 @@ async def test_fire_event_lands_on_main_bus(hass: HomeAssistant) -> None:
hass.bus.async_listen("zha_event", _on_zha)
try:
await sandbox_channel.push(
"sandbox_v2/fire_event",
{
"event_type": "zha_event",
"event_data": {"command": "on", "device_ieee": "0a:0b:0c"},
},
)
fire_event = pb.FireEvent(event_type="zha_event")
fire_event.event_data.update({"command": "on", "device_ieee": "0a:0b:0c"})
await sandbox_channel.push("sandbox_v2/fire_event", fire_event)
# Give the push handler a tick to run.
for _ in range(20):
if received:
+2 -2
View File
@@ -68,7 +68,7 @@ async def test_from_transport_round_trips() -> None:
@pytest.fixture(name="channels")
async def _channels_fixture() -> tuple:
"""Return a paired Channel + Channel, both started, both auto-cleaned."""
channel_a, channel_b = make_channel_pair()
channel_a, channel_b = make_channel_pair(use_json=True)
channel_a.start()
channel_b.start()
yield channel_a, channel_b
@@ -221,7 +221,7 @@ async def test_concurrency_cap_queues_excess_handlers() -> None:
and observes the last one waiting until one of the first two
finishes.
"""
channel_a, channel_b = make_channel_pair(max_inflight_b=2)
channel_a, channel_b = make_channel_pair(max_inflight_b=2, use_json=True)
channel_a.start()
channel_b.start()
started: list[asyncio.Event] = [asyncio.Event() for _ in range(3)]
+21 -16
View File
@@ -19,7 +19,6 @@ batcher behaviour does not.
from __future__ import annotations
import time
from typing import Any
from hass_client.testing.pytest_plugin import (
DEFAULT_GROUP,
@@ -28,6 +27,11 @@ from hass_client.testing.pytest_plugin import (
)
import pytest
from homeassistant.components.sandbox_v2._proto import sandbox_v2_pb2 as pb
from homeassistant.components.sandbox_v2.messages import (
make_entity_description,
struct_to_dict,
)
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import STATE_OFF
from homeassistant.core import HomeAssistant
@@ -85,10 +89,11 @@ async def test_area_call_against_200_lights_completes_under_budget(
# Watch the sandbox-side call_service handler so we can prove the
# batcher coalesced N entity invocations into one RPC.
received: list[dict[str, Any]] = []
received: list[pb.CallService] = []
async def _on_call_service(payload: dict[str, Any]) -> None:
async def _on_call_service(payload: pb.CallService) -> pb.CallServiceResult:
received.append(payload)
return pb.CallServiceResult()
# Replace the runtime's handler — we want our own bookkeeping for the
# benchmark, not the runtime's normal dispatch.
@@ -100,19 +105,19 @@ async def test_area_call_against_200_lights_completes_under_budget(
# us assign it to the perf area.
entity_ids: list[str] = []
for index in range(_LIGHT_COUNT):
payload = {
"entry_id": entry.entry_id,
"domain": "light",
"sandbox_entity_id": f"light.bench_{index:03d}",
"unique_id": f"bench-{index:03d}",
"name": f"Bench {index:03d}",
"supported_features": 0,
"capabilities": {"supported_color_modes": ["onoff"]},
"initial_state": STATE_OFF,
"initial_attributes": {"color_mode": "onoff"},
}
payload = make_entity_description(
entry_id=entry.entry_id,
domain="light",
sandbox_entity_id=f"light.bench_{index:03d}",
unique_id=f"bench-{index:03d}",
name=f"Bench {index:03d}",
supported_features=0,
capabilities={"supported_color_modes": ["onoff"]},
initial_state=STATE_OFF,
initial_attributes={"color_mode": "onoff"},
)
result = await runtime_channel.call("sandbox_v2/register_entity", payload)
entity_id = result["entity_id"]
entity_id = result.entity_id
entity_ids.append(entity_id)
entity_registry.async_update_entity(entity_id, area_id=area.id)
@@ -145,7 +150,7 @@ async def test_area_call_against_200_lights_completes_under_budget(
assert 1 <= len(received) <= 2, received
flattened: list[str] = []
for payload in received:
targets = payload["target"]["entity_id"]
targets = struct_to_dict(payload.target)["entity_id"]
flattened.extend(targets if isinstance(targets, list) else [targets])
assert sorted(flattened) == sorted(entity_ids)
@@ -24,8 +24,13 @@ from typing import Any
import pytest
from homeassistant.components.sandbox_v2._proto import sandbox_v2_pb2 as pb
from homeassistant.components.sandbox_v2.bridge import SandboxBridge
from homeassistant.components.sandbox_v2.channel import Channel
from homeassistant.components.sandbox_v2.messages import (
make_entity_description,
struct_to_dict,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
@@ -379,39 +384,34 @@ async def test_phase13_proxy_smoke(
"""Each Phase-13 proxy registers, accepts state, and translates a method."""
bridge, main_channel, sandbox_channel = await _wire(hass)
calls: list[dict[str, Any]] = []
calls: list[pb.CallService] = []
async def _on_call_service(payload: dict[str, Any]) -> Any:
async def _on_call_service(payload: pb.CallService) -> Any:
calls.append(payload)
return None
sandbox_channel.register("sandbox_v2/call_service", _on_call_service)
sandbox_entity_id = f"{domain}.synthetic"
payload = {
"entry_id": entry.entry_id,
"domain": domain,
"sandbox_entity_id": sandbox_entity_id,
"unique_id": f"sandbox-{domain}",
"supported_features": register_extras.get("supported_features", 0),
"capabilities": register_extras.get("capabilities", {}),
"initial_state": state_value,
"initial_attributes": dict(state_attrs),
}
payload = make_entity_description(
entry_id=entry.entry_id,
domain=domain,
sandbox_entity_id=sandbox_entity_id,
unique_id=f"sandbox-{domain}",
supported_features=register_extras.get("supported_features", 0),
capabilities=register_extras.get("capabilities", {}),
initial_state=state_value,
initial_attributes=dict(state_attrs),
)
try:
result = await sandbox_channel.call("sandbox_v2/register_entity", payload)
# State must round-trip through the cache.
await sandbox_channel.push(
"sandbox_v2/state_changed",
{
"sandbox_entity_id": sandbox_entity_id,
"new_state": {
"state": state_value,
"attributes": dict(state_attrs),
},
},
state_changed = pb.StateChanged(
sandbox_entity_id=sandbox_entity_id, state=state_value
)
state_changed.attributes.update(dict(state_attrs))
await sandbox_channel.push("sandbox_v2/state_changed", state_changed)
# Let the state push run.
for _ in range(20):
await asyncio.sleep(0)
@@ -423,12 +423,12 @@ async def test_phase13_proxy_smoke(
await main_channel.close()
await sandbox_channel.close()
assert result["entity_id"].startswith(f"{domain}.")
state = hass.states.get(result["entity_id"])
assert result.entity_id.startswith(f"{domain}.")
state = hass.states.get(result.entity_id)
assert state is not None, f"state machine missing entry for {domain}"
if method_name is not None:
assert len(calls) == 1, f"{domain}: expected one call_service RPC"
assert calls[0]["domain"] == domain
assert calls[0]["service"] == expected_service
assert calls[0]["target"] == {"entity_id": [sandbox_entity_id]}
assert calls[0].domain == domain
assert calls[0].service == expected_service
assert struct_to_dict(calls[0].target) == {"entity_id": [sandbox_entity_id]}
+61 -62
View File
@@ -22,9 +22,11 @@ import voluptuous_serialize
from homeassistant import data_entry_flow
from homeassistant.components.sandbox_v2 import schema_bridge
from homeassistant.components.sandbox_v2._proto import sandbox_v2_pb2 as pb
from homeassistant.components.sandbox_v2.bridge import SandboxBridge
from homeassistant.components.sandbox_v2.channel import Channel
from homeassistant.components.sandbox_v2.manager import SandboxManager
from homeassistant.components.sandbox_v2.messages import struct_to_dict
from homeassistant.components.sandbox_v2.router import SandboxFlowRouter
from homeassistant.components.sandbox_v2.schema_bridge import reconstruct_schema
from homeassistant.config_entries import SOURCE_USER, ConfigEntryState
@@ -40,11 +42,11 @@ from tests.common import MockConfigEntry, MockModule, mock_integration
class _SandboxStub:
"""Tiny script-driven sandbox dispatcher for proxy-flow tests."""
def __init__(self, responses: list[dict[str, Any]]) -> None:
def __init__(self, responses: list[pb.FlowResult]) -> None:
self._responses = responses
self.init_calls: list[dict[str, Any]] = []
self.step_calls: list[dict[str, Any]] = []
self.unload_calls: list[dict[str, Any]] = []
self.init_calls: list[pb.FlowInit] = []
self.step_calls: list[pb.FlowStep] = []
self.unload_calls: list[pb.EntryUnload] = []
def attach(self, channel: Channel) -> None:
channel.register("sandbox_v2/flow_init", self._flow_init)
@@ -53,31 +55,31 @@ class _SandboxStub:
channel.register("sandbox_v2/entry_setup", self._entry_setup)
channel.register("sandbox_v2/entry_unload", self._entry_unload)
async def _flow_init(self, payload: dict[str, Any]) -> dict[str, Any]:
async def _flow_init(self, payload: pb.FlowInit) -> pb.FlowResult:
self.init_calls.append(payload)
return self._pop()
async def _flow_step(self, payload: dict[str, Any]) -> dict[str, Any]:
async def _flow_step(self, payload: pb.FlowStep) -> pb.FlowResult:
self.step_calls.append(payload)
return self._pop()
async def _flow_abort(self, _payload: dict[str, Any]) -> dict[str, Any]:
return {}
async def _flow_abort(self, _payload: pb.FlowAbort) -> pb.FlowAbortResult:
return pb.FlowAbortResult()
async def _entry_setup(self, _payload: dict[str, Any]) -> dict[str, Any]:
return {"ok": True}
async def _entry_setup(self, _payload: pb.EntrySetup) -> pb.EntrySetupResult:
return pb.EntrySetupResult(ok=True)
async def _entry_unload(self, payload: dict[str, Any]) -> dict[str, Any]:
async def _entry_unload(self, payload: pb.EntryUnload) -> pb.EntryUnloadResult:
self.unload_calls.append(payload)
return {"ok": True}
return pb.EntryUnloadResult(ok=True)
def _pop(self) -> dict[str, Any]:
def _pop(self) -> pb.FlowResult:
return self._responses.pop(0)
@contextlib.contextmanager
def _wired_sandbox(
manager: FakeSandboxManager, *, group: str, responses: list[dict[str, Any]]
manager: FakeSandboxManager, *, group: str, responses: list[pb.FlowResult]
) -> Iterator[_SandboxStub]:
main_channel, sandbox_channel = make_channel_pair(
name_a=f"main-{group}", name_b=f"sandbox-{group}"
@@ -216,15 +218,14 @@ async def test_flow_form_renders_reconstructed_schema(
serialized_schema = [
{"name": "host", "type": "string", "required": True},
]
responses = [
{
"type": FlowResultType.FORM.value,
"flow_id": "sandbox-flow-schema",
"handler": "phase14_schema",
"step_id": "user",
"data_schema": serialized_schema,
},
]
form = pb.FlowResult(
type=FlowResultType.FORM.value,
flow_id="sandbox-flow-schema",
handler="phase14_schema",
step_id="user",
)
form.data_schema.extend(serialized_schema)
responses = [form]
with (
_wired_sandbox(manager, group="built-in", responses=responses),
@@ -263,9 +264,9 @@ async def test_register_service_with_schema_validates_on_main(
main_channel.start()
sandbox_channel.start()
seen: list[dict[str, Any]] = []
seen: list[pb.CallService] = []
async def _on_call_service(payload: dict[str, Any]) -> Any:
async def _on_call_service(payload: pb.CallService) -> Any:
seen.append(payload)
return None
@@ -275,17 +276,18 @@ async def test_register_service_with_schema_validates_on_main(
{"name": "host", "type": "string", "required": True},
]
register_service = pb.RegisterService(
domain="phase14_svc",
service="do_thing",
supports_response="none",
)
register_service.schema.extend(schema_payload)
try:
result = await sandbox_channel.call(
"sandbox_v2/register_service",
{
"domain": "phase14_svc",
"service": "do_thing",
"supports_response": "none",
"schema": schema_payload,
},
"sandbox_v2/register_service", register_service
)
assert result["installed"] is True
assert result.installed is True
with pytest.raises(vol.Invalid):
await hass.services.async_call(
@@ -297,7 +299,7 @@ async def test_register_service_with_schema_validates_on_main(
"phase14_svc", "do_thing", {"host": "1.2.3.4"}, blocking=True
)
assert len(seen) == 1
assert seen[0]["service_data"] == {"host": "1.2.3.4"}
assert struct_to_dict(seen[0].service_data) == {"host": "1.2.3.4"}
finally:
await main_channel.close()
await sandbox_channel.close()
@@ -314,15 +316,14 @@ async def test_unique_id_propagates_to_proxy_context(
) -> None:
"""A sandbox-side ``unique_id`` is mirrored onto the proxy's context."""
mock_integration(hass, MockModule("phase14_unique"))
responses = [
{
"type": FlowResultType.FORM.value,
"flow_id": "sandbox-flow-uid",
"handler": "phase14_unique",
"step_id": "user",
"context": {"source": SOURCE_USER, "unique_id": "abc-123"},
}
]
form = pb.FlowResult(
type=FlowResultType.FORM.value,
flow_id="sandbox-flow-uid",
handler="phase14_unique",
step_id="user",
)
form.context.update({"source": SOURCE_USER, "unique_id": "abc-123"})
responses = [form]
with (
_wired_sandbox(manager, group="built-in", responses=responses),
@@ -349,24 +350,22 @@ async def test_duplicate_unique_id_aborts_second_flow(
) -> None:
"""A second flow with the same propagated unique_id aborts on main."""
mock_integration(hass, MockModule("phase14_duplicate"))
responses_a = [
{
"type": FlowResultType.FORM.value,
"flow_id": "sandbox-flow-dup-a",
"handler": "phase14_duplicate",
"step_id": "user",
"context": {"source": SOURCE_USER, "unique_id": "dup-1"},
}
]
responses_b = [
{
"type": FlowResultType.FORM.value,
"flow_id": "sandbox-flow-dup-b",
"handler": "phase14_duplicate",
"step_id": "user",
"context": {"source": SOURCE_USER, "unique_id": "dup-1"},
}
]
form_a = pb.FlowResult(
type=FlowResultType.FORM.value,
flow_id="sandbox-flow-dup-a",
handler="phase14_duplicate",
step_id="user",
)
form_a.context.update({"source": SOURCE_USER, "unique_id": "dup-1"})
responses_a = [form_a]
form_b = pb.FlowResult(
type=FlowResultType.FORM.value,
flow_id="sandbox-flow-dup-b",
handler="phase14_duplicate",
step_id="user",
)
form_b.context.update({"source": SOURCE_USER, "unique_id": "dup-1"})
responses_b = [form_b]
with (
_wired_sandbox(manager, group="built-in", responses=responses_a + responses_b),
@@ -432,7 +431,7 @@ async def test_async_unload_consults_router_for_sandboxed_entry(
assert entry.state is ConfigEntryState.NOT_LOADED
# The sandbox saw exactly one entry_unload call.
assert len(stub.unload_calls) == 1
assert stub.unload_calls[0]["entry_id"] == entry.entry_id
assert stub.unload_calls[0].entry_id == entry.entry_id
async def test_async_unload_falls_through_for_non_sandboxed_entry(
@@ -4,11 +4,13 @@ from typing import Any
import pytest
from homeassistant.components.sandbox_v2._proto import sandbox_v2_pb2 as pb
from homeassistant.components.sandbox_v2.bridge import (
SandboxBridge,
SandboxEntityDescription,
)
from homeassistant.components.sandbox_v2.channel import Channel, ChannelRemoteError
from homeassistant.components.sandbox_v2.messages import make_entity_description
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import (
@@ -51,20 +53,18 @@ def _register_payload(
sandbox_entity_id: str = "light.kitchen",
unique_id: str = "sandbox-kitchen",
device_info: dict[str, Any] | None = None,
) -> dict[str, Any]:
payload: dict[str, Any] = {
"entry_id": entry.entry_id,
"domain": "light",
"sandbox_entity_id": sandbox_entity_id,
"unique_id": unique_id,
"supported_features": 0,
"capabilities": {"supported_color_modes": ["onoff"]},
"initial_state": "on",
"initial_attributes": {"color_mode": "onoff"},
}
if device_info is not None:
payload["device_info"] = device_info
return payload
) -> pb.EntityDescription:
return make_entity_description(
entry_id=entry.entry_id,
domain="light",
sandbox_entity_id=sandbox_entity_id,
unique_id=unique_id,
supported_features=0,
capabilities={"supported_color_modes": ["onoff"]},
initial_state="on",
initial_attributes={"color_mode": "onoff"},
device_info=device_info,
)
async def test_register_entity_creates_device_entry(
@@ -98,7 +98,7 @@ async def test_register_entity_creates_device_entry(
# The DeviceEntry must be linked to the sandboxed config entry.
assert entry.entry_id in device.config_entries
# The main-side entity_registry entry has device_id set to that DeviceEntry.
er_entry = er.async_get(hass).async_get(result["entity_id"])
er_entry = er.async_get(hass).async_get(result.entity_id)
assert er_entry is not None
assert er_entry.device_id == device.id
@@ -184,7 +184,7 @@ async def test_area_assignment_propagates_to_proxy(
refreshed_device = dr.async_get(hass).async_get(device.id)
assert refreshed_device is not None
assert refreshed_device.area_id == area.id
er_entry = er.async_get(hass).async_get(result["entity_id"])
er_entry = er.async_get(hass).async_get(result.entity_id)
assert er_entry is not None
assert er_entry.device_id == device.id
@@ -207,21 +207,21 @@ async def test_invalid_device_info_surfaces_remote_error(
await sandbox_channel.close()
async def test_description_from_payload_reconstructs_typed_device_info() -> None:
"""``SandboxEntityDescription.from_payload`` rebuilds set/tuple shapes."""
description = SandboxEntityDescription.from_payload(
{
"entry_id": "abc",
"domain": "sensor",
"sandbox_entity_id": "sensor.temp",
"device_info": {
async def test_description_from_proto_reconstructs_typed_device_info() -> None:
"""``SandboxEntityDescription.from_proto`` rebuilds set/tuple shapes."""
description = SandboxEntityDescription.from_proto(
make_entity_description(
entry_id="abc",
domain="sensor",
sandbox_entity_id="sensor.temp",
device_info={
"identifiers": [["foo", "1"], ["foo", "2"]],
"connections": [["mac", "00:11:22"]],
"via_device": ["parent_domain", "parent-1"],
"entry_type": "service",
"name": "Thermo",
},
}
)
)
assert description.device_info is not None
info = description.device_info
@@ -57,7 +57,7 @@ async def test_subprocess_handshake_and_ping(manager: SandboxManager) -> None:
assert channel is not None
result = await asyncio.wait_for(channel.call("sandbox_v2/ping", None), timeout=5.0)
assert result == {"pong": "sandbox_v2"}
assert result.pong == "sandbox_v2"
await manager.async_stop("built-in")
assert sandbox.state == "stopped"
@@ -16,6 +16,7 @@ import sys
import pytest
from homeassistant.components.sandbox_v2._proto import sandbox_v2_pb2 as pb
from homeassistant.components.sandbox_v2.manager import (
SandboxConfig,
SandboxManager,
@@ -93,12 +94,13 @@ async def test_graceful_shutdown_falls_through_to_sigterm_on_timeout(
sys.executable,
"-c",
(
"import sys, time, struct, json;"
# Length-prefixed Ready push frame — the manager's
# StreamTransport reads this and flips to "running".
"body = json.dumps("
"{'type': 'sandbox_v2/ready', 'payload': None}"
").encode();"
"import sys, time, struct;"
"from hass_client._proto import sandbox_v2_pb2 as pb;"
# Length-prefixed protobuf Ready push frame — the manager's
# ProtobufCodec decodes this and flips to "running".
"frame = pb.Frame(id=0, type='sandbox_v2/ready');"
"frame.request = pb.Ready().SerializeToString();"
"body = frame.SerializeToString();"
"sys.stdout.buffer.write(struct.pack('>I', len(body)) + body);"
"sys.stdout.buffer.flush();"
# Just sleep — stdin is wired to the manager but we never read.
@@ -172,9 +174,9 @@ async def test_on_shutdown_reply_callback_is_invoked(
hass_client-side ``test_shutdown`` suite here we only pin that
the callback wiring fires.
"""
replies: list[tuple[str, dict]] = []
replies: list[tuple[str, pb.ShutdownResult]] = []
async def _on_shutdown_reply(group: str, reply: dict) -> None:
async def _on_shutdown_reply(group: str, reply: pb.ShutdownResult) -> None:
replies.append((group, reply))
def _factory(group: str) -> list[str]:
@@ -205,7 +207,8 @@ async def test_on_shutdown_reply_callback_is_invoked(
assert len(replies) == 1
group, reply = replies[0]
assert group == "built-in"
assert reply["ok"] is True
assert reply["unloaded"] == 0
assert reply.ok is True
assert reply.unloaded == 0
# No integration was loaded → no RestoreEntity → no snapshot.
assert reply["restore_state"] is None
# proto: optional field unset (was `restore_state is None`).
assert not reply.HasField("restore_state")
@@ -0,0 +1,209 @@
"""T2 transport tests: ProtobufCodec round-trips + the Context security model.
Covers the three guarantees the protobuf wire adds on top of T1:
* a frame survives an encode decode re-encode cycle byte-identically (no
field drops), including fidelity #7's structured voluptuous error data;
* :meth:`SandboxBridge._resolve_context` reuses a known Context and mints a
fresh one attributed to the sandbox system user, never carrying a
sandbox-supplied ``parent_id`` for an unseen id;
* a sandbox-emitted ``state_changed`` carrying a ``context_id`` lands on main
with a Context owned by the sandbox system user and no ``parent_id``.
"""
import asyncio
import pytest
from homeassistant.components.sandbox_v2._proto import sandbox_v2_pb2 as pb
from homeassistant.components.sandbox_v2.auth import async_get_or_create_sandbox_user
from homeassistant.components.sandbox_v2.bridge import SandboxBridge
from homeassistant.components.sandbox_v2.channel import Frame
from homeassistant.components.sandbox_v2.codec_protobuf import ProtobufCodec
from homeassistant.components.sandbox_v2.messages import (
make_entity_description,
struct_to_dict,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant
from ._helpers import make_channel_pair
from tests.common import MockConfigEntry
@pytest.fixture(name="entry")
def _entry_fixture(hass: HomeAssistant) -> ConfigEntry:
"""A loaded light MockConfigEntry registered against ``hass``."""
entry = MockConfigEntry(
domain="light", title="Sandboxed Hue", data={"host": "1.2.3.4"}
)
entry.add_to_hass(hass)
return entry
def test_protobuf_codec_round_trip_is_byte_identical() -> None:
"""A full EntityDescription frame re-encodes byte-for-byte after a decode."""
codec = ProtobufCodec()
desc = make_entity_description(
entry_id="entry-1",
domain="light",
sandbox_entity_id="light.kitchen",
unique_id="u-1",
name="Kitchen",
has_entity_name=True,
supported_features=3,
capabilities={"supported_color_modes": ["onoff", "brightness"]},
initial_state="on",
initial_attributes={"brightness": 255, "color_mode": "brightness"},
device_info={
"identifiers": [["demo", "dev-1"]],
"name": "Lamp",
"sw_version": "1.0",
},
)
frame = Frame.call(7, "sandbox_v2/register_entity", desc)
wire1 = codec.encode(frame)
decoded = codec.decode(wire1)
wire2 = codec.encode(decoded)
assert wire1 == wire2
# And no nested field was dropped on the way through.
assert decoded.payload.info.description.name == "Kitchen"
assert decoded.payload.info.description.supported_features == 3
assert decoded.payload.initial.state == "on"
assert struct_to_dict(decoded.payload.initial.capabilities) == {
"supported_color_modes": ["onoff", "brightness"]
}
def test_protobuf_codec_round_trips_response_result() -> None:
"""A success response carries its typed result class through the codec."""
codec = ProtobufCodec()
frame = Frame.ok_response(
9,
pb.RegisterEntityResult(entity_id="light.kitchen_2"),
"sandbox_v2/register_entity",
)
decoded = codec.decode(codec.encode(frame))
assert decoded.ok is True
assert decoded.result.entity_id == "light.kitchen_2"
def test_protobuf_codec_round_trips_invalid_error_data() -> None:
"""Fidelity #7's single-Invalid structured data survives the proto wire."""
codec = ProtobufCodec()
frame = Frame.error_response(
3,
"expected int",
"Invalid",
{"kind": "invalid", "msg": "expected int", "path": ["options", "count"]},
"sandbox_v2/call_service",
)
decoded = codec.decode(codec.encode(frame))
assert decoded.ok is False
assert decoded.error == "expected int"
assert decoded.error_type == "Invalid"
assert decoded.error_data == {
"kind": "invalid",
"msg": "expected int",
"path": ["options", "count"],
}
def test_protobuf_codec_round_trips_multiple_invalid_error_data() -> None:
"""A MultipleInvalid keeps its ``multiple`` discriminator + every child."""
codec = ProtobufCodec()
error_data = {
"kind": "multiple",
"errors": [
{"kind": "invalid", "msg": "expected int", "path": ["count"]},
{"kind": "invalid", "msg": "required key", "path": ["name"]},
],
}
frame = Frame.error_response(
4, "two errors", "MultipleInvalid", error_data, "sandbox_v2/call_service"
)
decoded = codec.decode(codec.encode(frame))
assert decoded.error_type == "MultipleInvalid"
assert decoded.error_data == error_data
async def test_resolve_context_caches_known_and_mints_unknown(
hass: HomeAssistant,
) -> None:
"""A known context_id reuses its Context; an unseen one is minted safely."""
main_channel, sandbox_channel = make_channel_pair(name_a="main", name_b="sandbox")
bridge = SandboxBridge(hass, group="built-in", channel=main_channel)
user = await async_get_or_create_sandbox_user(hass, "built-in")
try:
known = Context(user_id=user.id, id="known-id")
bridge._contexts["known-id"] = known
# A known id returns the exact cached Context.
assert await bridge._resolve_context("known-id") is known
# An unseen id mints a fresh Context: the sandbox-supplied id is kept,
# but it is attributed to the sandbox system user with no parent_id.
minted = await bridge._resolve_context("fresh-id")
assert minted.id == "fresh-id"
assert minted.parent_id is None
assert minted.user_id == user.id
# And caching makes a second resolve return the same object.
assert await bridge._resolve_context("fresh-id") is minted
# No id at all → a system-user Context, still no parent_id.
anon = await bridge._resolve_context(None)
assert anon.parent_id is None
assert anon.user_id == user.id
finally:
await main_channel.close()
await sandbox_channel.close()
async def test_state_changed_context_attributed_to_sandbox_system_user(
hass: HomeAssistant, entry: ConfigEntry
) -> None:
"""A sandbox state_changed with a context_id lands owned by the system user."""
main_channel, sandbox_channel = make_channel_pair(name_a="main", name_b="sandbox")
# Constructing the bridge registers the inbound handlers on main_channel.
SandboxBridge(hass, group="built-in", channel=main_channel)
main_channel.start()
sandbox_channel.start()
desc = make_entity_description(
entry_id=entry.entry_id,
domain="light",
sandbox_entity_id="light.lamp",
unique_id="sandbox-lamp",
supported_features=0,
capabilities={"supported_color_modes": ["onoff"]},
initial_state="off",
initial_attributes={"color_mode": "onoff"},
)
try:
result = await sandbox_channel.call("sandbox_v2/register_entity", desc)
entity_id = result.entity_id
changed = pb.StateChanged(
sandbox_entity_id="light.lamp", state="on", context_id="sandbox-ctx-1"
)
changed.attributes.update({"color_mode": "onoff"})
await sandbox_channel.push("sandbox_v2/state_changed", changed)
for _ in range(200):
state = hass.states.get(entity_id)
if state is not None and state.state == "on":
break
await asyncio.sleep(0.01)
finally:
await main_channel.close()
await sandbox_channel.close()
user = await async_get_or_create_sandbox_user(hass, "built-in")
state = hass.states.get(entity_id)
assert state is not None
assert state.state == "on"
# The sandbox only sent a context_id; main owns the authoritative Context.
assert state.context.id == "sandbox-ctx-1"
assert state.context.user_id == user.id
assert state.context.parent_id is None
+57 -52
View File
@@ -3,13 +3,15 @@
import asyncio
from collections.abc import Iterator
import contextlib
from typing import Any, cast
from typing import cast
from unittest.mock import patch
import pytest
from homeassistant.components.sandbox_v2._proto import sandbox_v2_pb2 as pb
from homeassistant.components.sandbox_v2.channel import Channel
from homeassistant.components.sandbox_v2.manager import SandboxManager
from homeassistant.components.sandbox_v2.messages import struct_to_dict
from homeassistant.components.sandbox_v2.router import SandboxFlowRouter
from homeassistant.config_entries import SOURCE_USER, ConfigEntryState
from homeassistant.core import HomeAssistant
@@ -23,11 +25,11 @@ from tests.common import MockModule, mock_integration
class _SandboxStub:
"""Tiny sandbox-side dispatcher backed by a script of canned responses."""
def __init__(self, responses: list[dict[str, Any]]) -> None:
def __init__(self, responses: list[pb.FlowResult]) -> None:
self._responses = responses
self.init_calls: list[dict[str, Any]] = []
self.step_calls: list[dict[str, Any]] = []
self.abort_calls: list[dict[str, Any]] = []
self.init_calls: list[pb.FlowInit] = []
self.step_calls: list[pb.FlowStep] = []
self.abort_calls: list[pb.FlowAbort] = []
def attach(self, channel: Channel) -> None:
channel.register("sandbox_v2/flow_init", self._flow_init)
@@ -36,31 +38,31 @@ class _SandboxStub:
channel.register("sandbox_v2/entry_setup", self._entry_setup)
channel.register("sandbox_v2/entry_unload", self._entry_unload)
async def _flow_init(self, payload: dict[str, Any]) -> dict[str, Any]:
async def _flow_init(self, payload: pb.FlowInit) -> pb.FlowResult:
self.init_calls.append(payload)
return self._pop()
async def _flow_step(self, payload: dict[str, Any]) -> dict[str, Any]:
async def _flow_step(self, payload: pb.FlowStep) -> pb.FlowResult:
self.step_calls.append(payload)
return self._pop()
async def _flow_abort(self, payload: dict[str, Any]) -> dict[str, Any]:
async def _flow_abort(self, payload: pb.FlowAbort) -> pb.FlowAbortResult:
self.abort_calls.append(payload)
return {}
return pb.FlowAbortResult()
async def _entry_setup(self, _payload: dict[str, Any]) -> dict[str, Any]:
return {"ok": True}
async def _entry_setup(self, _payload: pb.EntrySetup) -> pb.EntrySetupResult:
return pb.EntrySetupResult(ok=True)
async def _entry_unload(self, _payload: dict[str, Any]) -> dict[str, Any]:
return {"ok": True}
async def _entry_unload(self, _payload: pb.EntryUnload) -> pb.EntryUnloadResult:
return pb.EntryUnloadResult(ok=True)
def _pop(self) -> dict[str, Any]:
def _pop(self) -> pb.FlowResult:
return self._responses.pop(0)
@contextlib.contextmanager
def _wired_sandbox(
manager: FakeSandboxManager, *, group: str, responses: list[dict[str, Any]]
manager: FakeSandboxManager, *, group: str, responses: list[pb.FlowResult]
) -> Iterator[_SandboxStub]:
"""Wire a sandbox stub onto a fresh in-memory channel pair."""
main_channel, sandbox_channel = make_channel_pair(
@@ -102,22 +104,23 @@ async def test_full_flow_user_to_create_entry(
) -> None:
"""A user-initiated flow that asks for input then creates an entry."""
mock_integration(hass, MockModule("test_proxy_full"))
create_entry = pb.FlowResult(
type=FlowResultType.CREATE_ENTRY.value,
flow_id="sandbox-flow-1",
handler="test_proxy_full",
title="Proxy Title",
)
create_entry.data.update({"host": "1.2.3.4"})
responses = [
# Response to flow_init — show a form
{
"type": FlowResultType.FORM.value,
"flow_id": "sandbox-flow-1",
"handler": "test_proxy_full",
"step_id": "user",
},
pb.FlowResult(
type=FlowResultType.FORM.value,
flow_id="sandbox-flow-1",
handler="test_proxy_full",
step_id="user",
),
# Response to flow_step — create the entry
{
"type": FlowResultType.CREATE_ENTRY.value,
"flow_id": "sandbox-flow-1",
"handler": "test_proxy_full",
"title": "Proxy Title",
"data": {"host": "1.2.3.4"},
},
create_entry,
]
with (
@@ -144,12 +147,13 @@ async def test_full_flow_user_to_create_entry(
assert result["title"] == "Proxy Title"
assert len(stub.init_calls) == 1
assert stub.init_calls[0]["handler"] == "test_proxy_full"
assert stub.init_calls[0]["context"]["source"] == SOURCE_USER
assert stub.init_calls[0]["data"] is None
assert stub.init_calls[0].handler == "test_proxy_full"
assert struct_to_dict(stub.init_calls[0].context)["source"] == SOURCE_USER
# proto: a USER-source init carries no `data` field (was `data is None`).
assert not stub.init_calls[0].HasField("data")
assert len(stub.step_calls) == 1
assert stub.step_calls[0]["flow_id"] == "sandbox-flow-1"
assert stub.step_calls[0]["user_input"] == {"host": "1.2.3.4"}
assert stub.step_calls[0].flow_id == "sandbox-flow-1"
assert struct_to_dict(stub.step_calls[0].user_input) == {"host": "1.2.3.4"}
# The new ConfigEntry is tagged with the sandbox group via the
# ConfigEntry.sandbox first-class field (Phase 17 — keeps the tag
@@ -167,20 +171,21 @@ async def test_form_with_errors_reshows(
) -> None:
"""A form returned with `errors` is shown as a fresh form on main."""
mock_integration(hass, MockModule("test_proxy_errors"))
reshow = pb.FlowResult(
type=FlowResultType.FORM.value,
flow_id="sandbox-flow-err",
handler="test_proxy_errors",
step_id="user",
)
reshow.errors.update({"host": "invalid_host"})
responses = [
{
"type": FlowResultType.FORM.value,
"flow_id": "sandbox-flow-err",
"handler": "test_proxy_errors",
"step_id": "user",
},
{
"type": FlowResultType.FORM.value,
"flow_id": "sandbox-flow-err",
"handler": "test_proxy_errors",
"step_id": "user",
"errors": {"host": "invalid_host"},
},
pb.FlowResult(
type=FlowResultType.FORM.value,
flow_id="sandbox-flow-err",
handler="test_proxy_errors",
step_id="user",
),
reshow,
]
with (
@@ -208,12 +213,12 @@ async def test_abort_is_propagated(
"""An ABORT from the sandbox surfaces as an abort on main."""
mock_integration(hass, MockModule("test_proxy_abort"))
responses = [
{
"type": FlowResultType.ABORT.value,
"flow_id": "sandbox-flow-abort",
"handler": "test_proxy_abort",
"reason": "already_configured",
}
pb.FlowResult(
type=FlowResultType.ABORT.value,
flow_id="sandbox-flow-abort",
handler="test_proxy_abort",
reason="already_configured",
)
]
with (
+10 -8
View File
@@ -4,7 +4,9 @@ from typing import cast
import pytest
from homeassistant.components.sandbox_v2._proto import sandbox_v2_pb2 as pb
from homeassistant.components.sandbox_v2.manager import SandboxManager
from homeassistant.components.sandbox_v2.messages import struct_to_dict
from homeassistant.components.sandbox_v2.proxy_flow import SandboxFlowProxy
from homeassistant.components.sandbox_v2.router import SandboxFlowRouter
from homeassistant.config_entries import SOURCE_USER, ConfigEntry, ConfigFlowContext
@@ -105,11 +107,11 @@ async def test_async_setup_entry_routes_to_sandbox(
main-side entry state.
"""
channel_a, channel_b = make_channel_pair()
received: list[dict[str, object]] = []
received: list[pb.EntrySetup] = []
async def _entry_setup(payload: dict[str, object]) -> dict[str, object]:
async def _entry_setup(payload: pb.EntrySetup) -> pb.EntrySetupResult:
received.append(payload)
return {"ok": True}
return pb.EntrySetupResult(ok=True)
channel_b.register("sandbox_v2/entry_setup", _entry_setup)
channel_a.start()
@@ -132,11 +134,11 @@ async def test_async_setup_entry_routes_to_sandbox(
assert result is True
assert manager.start_calls == ["built-in"]
assert len(received) == 1
assert received[0]["domain"] == "test_entry"
assert received[0]["title"] == "Test"
assert received[0].domain == "test_entry"
assert received[0].title == "Test"
# Sandbox group is carried as a first-class ConfigEntry field now;
# entry.data on the wire is exactly what the integration sees.
assert received[0]["data"] == {}
assert struct_to_dict(received[0].data) == {}
async def test_async_setup_entry_marks_setup_error_on_failure(
@@ -145,8 +147,8 @@ async def test_async_setup_entry_marks_setup_error_on_failure(
"""A sandbox refusing entry_setup propagates as SETUP_ERROR."""
channel_a, channel_b = make_channel_pair()
async def _entry_setup(_payload: dict[str, object]) -> dict[str, object]:
return {"ok": False, "reason": "boom"}
async def _entry_setup(_payload: pb.EntrySetup) -> pb.EntrySetupResult:
return pb.EntrySetupResult(ok=False, reason="boom")
channel_b.register("sandbox_v2/entry_setup", _entry_setup)
channel_a.start()
+50 -45
View File
@@ -22,8 +22,10 @@ from typing import Any
import pytest
from homeassistant.components.sandbox_v2._proto import sandbox_v2_pb2 as pb
from homeassistant.components.sandbox_v2.bridge import SandboxBridge
from homeassistant.components.sandbox_v2.channel import Channel, ChannelRemoteError
from homeassistant.components.sandbox_v2.messages import struct_to_dict
from homeassistant.core import HomeAssistant
from homeassistant.helpers.storage import STORAGE_DIR
@@ -48,26 +50,25 @@ def _store_path(hass: HomeAssistant, group: str, key: str) -> Path:
async def test_store_save_writes_to_namespaced_path(hass: HomeAssistant) -> None:
"""A save lands at ``.storage/sandbox_v2/<group>/<key>`` on main."""
_bridge, main_channel, sandbox_channel = await _wire(hass, group="built-in")
payload = {
wrapped = {
"version": 1,
"minor_version": 1,
"key": "phase8_demo",
"data": {
"version": 1,
"minor_version": 1,
"key": "phase8_demo",
"data": {"hello": "world"},
},
"data": {"hello": "world"},
}
save = pb.StoreSave(key="phase8_demo")
save.data.update(wrapped)
try:
result = await sandbox_channel.call("sandbox_v2/store_save", payload)
result = await sandbox_channel.call("sandbox_v2/store_save", save)
finally:
await main_channel.close()
await sandbox_channel.close()
assert result == {"ok": True}
assert result.ok
path = _store_path(hass, "built-in", "phase8_demo")
assert path.is_file()
# The file holds the wrapped Store payload verbatim.
assert json.loads(path.read_text(encoding="utf-8")) == payload["data"]
assert json.loads(path.read_text(encoding="utf-8")) == wrapped
async def test_store_load_returns_saved_payload(hass: HomeAssistant) -> None:
@@ -79,18 +80,18 @@ async def test_store_load_returns_saved_payload(hass: HomeAssistant) -> None:
"key": "phase8_demo",
"data": {"counter": 42},
}
save = pb.StoreSave(key="phase8_demo")
save.data.update(wrapped)
try:
await sandbox_channel.call(
"sandbox_v2/store_save", {"key": "phase8_demo", "data": wrapped}
)
await sandbox_channel.call("sandbox_v2/store_save", save)
loaded = await sandbox_channel.call(
"sandbox_v2/store_load", {"key": "phase8_demo"}
"sandbox_v2/store_load", pb.StoreLoad(key="phase8_demo")
)
finally:
await main_channel.close()
await sandbox_channel.close()
assert loaded == wrapped
assert struct_to_dict(loaded.data) == wrapped
async def test_store_load_missing_key_returns_none(hass: HomeAssistant) -> None:
@@ -98,42 +99,41 @@ async def test_store_load_missing_key_returns_none(hass: HomeAssistant) -> None:
_bridge, main_channel, sandbox_channel = await _wire(hass)
try:
loaded = await sandbox_channel.call(
"sandbox_v2/store_load", {"key": "never_saved"}
"sandbox_v2/store_load", pb.StoreLoad(key="never_saved")
)
finally:
await main_channel.close()
await sandbox_channel.close()
assert loaded is None
# proto: a missing key returns a StoreLoadResult with no `data` field set.
assert not loaded.HasField("data")
async def test_store_remove_unlinks_file(hass: HomeAssistant) -> None:
"""``store_remove`` removes the on-disk file."""
_bridge, main_channel, sandbox_channel = await _wire(hass)
save = pb.StoreSave(key="to_remove")
save.data.update(
{
"version": 1,
"minor_version": 1,
"key": "to_remove",
"data": {"x": 1},
}
)
try:
await sandbox_channel.call(
"sandbox_v2/store_save",
{
"key": "to_remove",
"data": {
"version": 1,
"minor_version": 1,
"key": "to_remove",
"data": {"x": 1},
},
},
)
await sandbox_channel.call("sandbox_v2/store_save", save)
path = _store_path(hass, "built-in", "to_remove")
assert path.is_file()
result = await sandbox_channel.call(
"sandbox_v2/store_remove", {"key": "to_remove"}
"sandbox_v2/store_remove", pb.StoreRemove(key="to_remove")
)
finally:
await main_channel.close()
await sandbox_channel.close()
assert result == {"ok": True}
assert result.ok
assert not _store_path(hass, "built-in", "to_remove").exists()
@@ -142,13 +142,13 @@ async def test_store_remove_missing_key_is_noop(hass: HomeAssistant) -> None:
_bridge, main_channel, sandbox_channel = await _wire(hass)
try:
result = await sandbox_channel.call(
"sandbox_v2/store_remove", {"key": "phantom"}
"sandbox_v2/store_remove", pb.StoreRemove(key="phantom")
)
finally:
await main_channel.close()
await sandbox_channel.close()
assert result == {"ok": True}
assert result.ok
@pytest.mark.parametrize(
@@ -166,7 +166,9 @@ async def test_store_rejects_path_traversal(hass: HomeAssistant, bad_key: str) -
_bridge, main_channel, sandbox_channel = await _wire(hass)
try:
with pytest.raises(ChannelRemoteError):
await sandbox_channel.call("sandbox_v2/store_load", {"key": bad_key})
await sandbox_channel.call(
"sandbox_v2/store_load", pb.StoreLoad(key=bad_key)
)
finally:
await main_channel.close()
await sandbox_channel.close()
@@ -177,7 +179,7 @@ async def test_store_rejects_missing_key(hass: HomeAssistant) -> None:
_bridge, main_channel, sandbox_channel = await _wire(hass)
try:
with pytest.raises(ChannelRemoteError):
await sandbox_channel.call("sandbox_v2/store_load", {})
await sandbox_channel.call("sandbox_v2/store_load", pb.StoreLoad(key=""))
finally:
await main_channel.close()
await sandbox_channel.close()
@@ -193,13 +195,13 @@ async def test_store_groups_are_isolated(hass: HomeAssistant) -> None:
"key": "shared_key",
"data": {"side": "built-in"},
}
save = pb.StoreSave(key="shared_key")
save.data.update(wrapped)
try:
await sandbox_a.call(
"sandbox_v2/store_save", {"key": "shared_key", "data": wrapped}
)
await sandbox_a.call("sandbox_v2/store_save", save)
# The custom-group bridge cannot see built-in's data.
loaded_custom = await sandbox_b.call(
"sandbox_v2/store_load", {"key": "shared_key"}
"sandbox_v2/store_load", pb.StoreLoad(key="shared_key")
)
finally:
await main_a.close()
@@ -207,7 +209,8 @@ async def test_store_groups_are_isolated(hass: HomeAssistant) -> None:
await main_b.close()
await sandbox_b.close()
assert loaded_custom is None
# proto: the custom group has no entry, so `data` is unset.
assert not loaded_custom.HasField("data")
assert _store_path(hass, "built-in", "shared_key").is_file()
assert not _store_path(hass, "custom", "shared_key").exists()
@@ -221,10 +224,10 @@ async def test_store_survives_bridge_restart(hass: HomeAssistant) -> None:
"key": "persistent",
"data": {"survives": True},
}
save = pb.StoreSave(key="persistent")
save.data.update(wrapped)
try:
await sandbox_a.call(
"sandbox_v2/store_save", {"key": "persistent", "data": wrapped}
)
await sandbox_a.call("sandbox_v2/store_save", save)
finally:
await main_a.close()
await sandbox_a.close()
@@ -232,9 +235,11 @@ async def test_store_survives_bridge_restart(hass: HomeAssistant) -> None:
# Bring up a fresh bridge for the same group on a new channel pair.
_bridge2, main_b, sandbox_b = await _wire(hass, group="built-in")
try:
loaded = await sandbox_b.call("sandbox_v2/store_load", {"key": "persistent"})
loaded = await sandbox_b.call(
"sandbox_v2/store_load", pb.StoreLoad(key="persistent")
)
finally:
await main_b.close()
await sandbox_b.close()
assert loaded == wrapped
assert struct_to_dict(loaded.data) == wrapped
@@ -80,7 +80,7 @@ async def test_inprocess_plugin_round_trips_ping(
data = hass.data[DATA_SANDBOX_V2]
channel = data.channels[DEFAULT_GROUP]
result = await asyncio.wait_for(channel.call("sandbox_v2/ping", None), timeout=2.0)
assert result == {"pong": "sandbox_v2"}
assert result.pong == "sandbox_v2"
async def test_inprocess_plugin_returns_existing_sandbox_on_ensure_started(