diff --git a/smart_control/llm/agents/base_agent.py b/smart_control/llm/agents/base_agent.py new file mode 100644 index 00000000..b074a1a8 --- /dev/null +++ b/smart_control/llm/agents/base_agent.py @@ -0,0 +1,66 @@ +"""Base class for agents that use the control loop.""" + +import abc +from collections.abc import Mapping +from collections.abc import Sequence +import dataclasses +from typing import Any + +from smart_buildings.smart_control.llm.schema import action_context +from smart_buildings.smart_control.proto import smart_control_building_pb2 as building_pb2 +from smart_buildings.smart_control.proto import smart_control_reward_pb2 as reward_pb2 +from smart_buildings.smart_control.utils import serialization + + +@dataclasses.dataclass +class AgentErrorRecord: + """Record of an error produced by an agent. + + Attributes: + error_type: The class name of the exception. + error_message: The string representation of the error. + details: Structured error details (e.g. from Pydantic's ValidationError). + metadata: Extra metadata about the error (e.g. raw response text). + """ + + error_type: str + error_message: str + details: Sequence[Mapping[str, Any]] | None = None + metadata: Mapping[str, Any] | None = None + + @property + def json_metadata(self) -> serialization.SerializableData: + """A JSON-serializable representation of the error record.""" + return serialization.to_serializable(dataclasses.asdict(self)) + + +class BaseControlAgent(abc.ABC): + """An agent that chooses actions based on info from the environment. + + Attributes: + errors: A list of errors recorded by the agent during its operation. + """ + + def __init__(self): + self.errors: list[AgentErrorRecord] = [] + + @abc.abstractmethod + def get_action_context( + self, + observation_response: building_pb2.ObservationResponse | None = None, + reward_info: reward_pb2.RewardInfo | None = None, + ) -> action_context.ActionContext: + """Returns an action context based on the agent's strategy / policy. + + Args: + observation_response: The observation response from the environment. + reward_info: The reward info from the environment. + + Returns: + An action context based on the agent's strategy / policy. + """ + + @property + def json_metadata(self) -> serialization.SerializableData: + """Info about the agent and its setup, to be written to a JSON file.""" + return {"type": self.__class__.__name__} diff --git a/smart_control/llm/agents/base_agent_test.py b/smart_control/llm/agents/base_agent_test.py new file mode 100644 index 00000000..661f6e92 --- /dev/null +++ b/smart_control/llm/agents/base_agent_test.py @@ -0,0 +1,42 @@ +import json + +from absl.testing import absltest +from smart_buildings.smart_control.llm.agents import base_agent + + +class ErrorRecordNestedExceptionsTest(absltest.TestCase): + + def setUp(self): + super().setUp() + nested_error = ValueError('Something went wrong') + + self.record = base_agent.AgentErrorRecord( + error_type='ValidationError', + error_message='Validation failed', + details=[{ + 'loc': ('field',), + 'ctx': {'error': nested_error}, + }], + metadata={'current_step': 4, 'response_txt': 'OOPS'}, + ) + + def test_json_metadata(self): + self.assertEqual( + self.record.json_metadata, + { + 'error_type': 'ValidationError', + 'error_message': 'Validation failed', + 'details': [{ + 'loc': ['field'], + 'ctx': {'error': 'Something went wrong'}, + }], + 'metadata': {'current_step': 4, 'response_txt': 'OOPS'}, + }, + ) + + def test_json_metadata_is_serializable(self): + self.assertIsInstance(json.dumps(self.record.json_metadata), str) + + +if __name__ == '__main__': + absltest.main() diff --git a/smart_control/llm/agents/default_agent.py b/smart_control/llm/agents/default_agent.py new file mode 100644 index 00000000..2e5898e1 --- /dev/null +++ b/smart_control/llm/agents/default_agent.py @@ -0,0 +1,124 @@ +"""Default policy agent. + +This agent employs a fixed strategy that uses the environment's default policy +values for all of its actions. + +This strategy is overly simplistic, but provides a decent foundation for +child classes to inherit from, and can be useful for testing and debugging the +agent control loop. +""" + +from typing import Final + +import numpy as np +from smart_buildings.smart_control.environment import environment +from smart_buildings.smart_control.environment import hybrid_action_environment +from smart_buildings.smart_control.llm.agents import base_agent +from smart_buildings.smart_control.llm.schema import action_context +from smart_buildings.smart_control.llm.schema import output_schema +from smart_buildings.smart_control.proto import smart_control_building_pb2 as building_pb2 +from smart_buildings.smart_control.proto import smart_control_reward_pb2 as reward_pb2 +from smart_buildings.smart_control.utils import serialization + +DEFAULT_JUSTIFICATION: Final[str] = "Default action." +DEFAULT_SETPOINT_JUSTIFICATION: Final[str] = "Default value." + + +class DefaultPolicyAgent(base_agent.BaseControlAgent): + """A control agent that uses the environment's default policy values. + + Attributes: + env: The environment to be controlled. Should be configured with default + policy values. + """ + + def __init__(self, env: environment.Environment, clip: bool = True): + """Initializes the instance. + + Args: + env: The environment to be controlled. Should be configured with + default policy values. + clip: Whether to clip setpoint values to the bounds of the valid range. If + `False`, raises `GuardrailsExceededError`. Otherwise, clips the + setpoint values to the valid range, and logs a record of the error. + Defaults to `True`. + """ + super().__init__() + self._clip = clip + self.env = self._validate_environment(env) + + def _validate_environment( + self, env: environment.Environment + ) -> environment.Environment: + """Ensures the environment has default values.""" + if env.action_names is None: + raise ValueError("Expecting environment to have action names.") + + if env.default_policy_values is None: + raise ValueError("Expecting environment to have default policy values.") + + if len(env.action_names) != len(env.default_policy_values): + raise ValueError( + "Expecting environment to have the same number of action names and" + " default policy values." + ) + + return env + + @property + def json_metadata(self) -> serialization.SerializableData: + """Info about the agent and its setup, to be written to a JSON file.""" + return super().json_metadata | { + "default_policy": { + "action_names": self.env.action_names, + "default_values": self.env.default_action_values, + }, + "clip": self._clip, + } + + @property + def action_context_class(self) -> type[action_context.ActionContext]: + """The action context class to be used by this agent.""" + if isinstance(self.env, hybrid_action_environment.HybridActionEnvironment): + return action_context.HybridActionContext + return action_context.ActionContext + + def get_default_action_context(self) -> action_context.ActionContext: + """Compiles an action context using the environment's default values.""" + + setpoints = [] + for action_name, normalized_value in zip( + self.env.action_names, self.env.default_action_values + ): + device_id, setpoint_name = self.env.id_map.inv[action_name] + normalizer = self.env.action_normalizers.get(setpoint_name) + if normalizer is None: + raise ValueError(f"No normalizer found for setpoint: {setpoint_name}") + + native_value = normalizer.setpoint_value(np.array(normalized_value)) + setpoints.append( + output_schema.DeviceSetpoint( + device_id=device_id, + setpoint_name=setpoint_name, + setpoint_value=native_value, + justification=DEFAULT_SETPOINT_JUSTIFICATION, + ) + ) + + return self.action_context_class( + env=self.env, + clip=self._clip, + timestamp=str(self.env.current_local_timestamp), + justification=DEFAULT_JUSTIFICATION, + validity_interval=self.env.time_step_mins, + setpoints=setpoints, + ) + + def get_action_context( + self, + observation_response: building_pb2.ObservationResponse | None = None, + reward_info: reward_pb2.RewardInfo | None = None, + ) -> action_context.ActionContext: + """The action context to be used within the agent control loop.""" + del observation_response, reward_info # Unused in this implementation. + return self.get_default_action_context() diff --git a/smart_control/llm/agents/default_agent_test.py b/smart_control/llm/agents/default_agent_test.py new file mode 100644 index 00000000..66a69cbb --- /dev/null +++ b/smart_control/llm/agents/default_agent_test.py @@ -0,0 +1,202 @@ +import json +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +import pandas as pd +from smart_buildings.smart_control.environment import conftest as env_conftest +from smart_buildings.smart_control.environment import environment +from smart_buildings.smart_control.environment import hybrid_action_environment +from smart_buildings.smart_control.llm.agents import default_agent +from smart_buildings.smart_control.llm.schema import action_context + + +class AgentEnvironmentValidationsTest(parameterized.TestCase): + + AGENT_CLASS = default_agent.DefaultPolicyAgent + + def setUp(self): + super().setUp() + self.env = mock.create_autospec(environment.Environment, instance=True) + self.env.action_names = list(env_conftest.DEFAULT_ACTIONS.keys()) + self.env.default_policy_values = list(env_conftest.DEFAULT_ACTIONS.values()) + + def test_valid_environment(self): + agent = self.AGENT_CLASS(self.env) + self.assertIsInstance(agent, self.AGENT_CLASS) + + def test_validate_action_names(self): + self.env.action_names = None + + with self.assertRaisesRegex( + ValueError, "Expecting environment to have action names." + ): + self.AGENT_CLASS(self.env) + + def test_validate_default_values(self): + self.env.default_policy_values = None + + with self.assertRaisesRegex( + ValueError, "Expecting environment to have default policy values." + ): + self.AGENT_CLASS(self.env) + + def test_validate_number_of_action_names_and_default_values(self): + self.env.action_names = self.env.action_names[1:] + + with self.assertRaisesRegex( + ValueError, + "Expecting environment to have the same number of action names and" + " default policy values.", + ): + self.AGENT_CLASS(self.env) + + +class DefaultAgentTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.env = self._create_environment() + self.agent = self._create_agent(self.env) + + def _create_environment( + self, start_timestamp: pd.Timestamp | None = None + ) -> environment.Environment: + return env_conftest.create_environment( + layout=env_conftest.DEMO_LAYOUT, + default_actions=env_conftest.DEFAULT_ACTIONS, + start_timestamp=start_timestamp, + ) + + def _create_agent( + self, env: environment.Environment + ) -> default_agent.DefaultPolicyAgent: + return default_agent.DefaultPolicyAgent(env=env) + + def test_initialization(self): + self.assertIsInstance(self.agent, default_agent.DefaultPolicyAgent) + + def test_environment(self): + self.assertIsInstance(self.agent.env, environment.Environment) + + def test_json_metadata(self): + self.assertEqual( + self.agent.json_metadata, + { + "type": "DefaultPolicyAgent", + "default_policy": { + "action_names": self.env.action_names, + "default_values": self.env.default_action_values, + }, + "clip": True, + }, + ) + + def test_json_metadata_is_serializable(self): + self.assertEqual( + self.agent.json_metadata, + json.loads(json.dumps(self.agent.json_metadata, indent=2)), + ) + + def test_default_action_context(self): + ctx = self.agent.get_default_action_context() + self.assertIsInstance(ctx, action_context.ActionContext) + + with self.subTest(name="timestamp"): + self.assertEqual(ctx.timestamp, str(self.env.current_local_timestamp)) + + with self.subTest(name="justification"): + self.assertEqual(ctx.justification, default_agent.DEFAULT_JUSTIFICATION) + + with self.subTest(name="validity_interval"): + self.assertEqual(ctx.validity_interval, self.env.time_step_mins) + + with self.subTest(name="setpoints"): + self.assertLen(ctx.setpoints, len(self.env.action_names)) + + # Setpoint and device names should match the env's action names: + names = [(sp.device_id, sp.setpoint_name) for sp in ctx.sorted_setpoints] + self.assertEqual( + names, + [ + ("air_handler_1", "supply_air_heating_temperature_setpoint"), + ("boiler_1", "supply_water_setpoint"), + ("air_handler_2", "supply_air_heating_temperature_setpoint"), + ], + ) + + # Setpoint values should be native versions of the env's default values: + setpoint_values = [sp.setpoint_value for sp in ctx.setpoints] + self.assertSequenceAlmostEqual(setpoint_values, [290.0, 310.0, 290.0]) + + def test_get_action_context(self): + ctx = self.agent.get_action_context() + self.assertIsInstance(ctx, action_context.ActionContext) + self.assertEqual(ctx, self.agent.get_default_action_context()) + + +class DefaultHybridActionAgentTest(DefaultAgentTest): + + def _create_environment( + self, start_timestamp: pd.Timestamp | None = None + ) -> hybrid_action_environment.HybridActionEnvironment: + return env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT, + default_actions=env_conftest.DEFAULT_HYBRID_ACTIONS, + start_timestamp=start_timestamp, + ) + + def test_environment(self): + self.assertIsInstance( + self.agent.env, hybrid_action_environment.HybridActionEnvironment + ) + + def test_default_action_context(self): + ctx = self.agent.get_default_action_context() + self.assertIsInstance(ctx, action_context.ActionContext) + + with self.subTest(name="timestamp"): + self.assertEqual(ctx.timestamp, str(self.env.current_local_timestamp)) + + with self.subTest(name="justification"): + self.assertEqual(ctx.justification, default_agent.DEFAULT_JUSTIFICATION) + + with self.subTest(name="validity_interval"): + self.assertEqual(ctx.validity_interval, self.env.time_step_mins) + + with self.subTest(name="setpoints"): + self.assertLen(ctx.setpoints, len(self.env.action_names)) + self.assertSequenceAlmostEqual( + ctx.get_action_values(), self.env.default_action_values + ) + + # Setpoint and device names should match the env's action names: + names = [(sp.device_id, sp.setpoint_name) for sp in ctx.sorted_setpoints] + with self.subTest(name="setpoint_names"): + self.assertEqual( + names, + [ + ("air_handler_1", "supply_air_heating_temperature_setpoint"), + ("air_handler_1", "supervisor_run_command"), + ("boiler_1", "supply_water_setpoint"), + ("boiler_1", "supervisor_run_command"), + ("air_handler_2", "supply_air_heating_temperature_setpoint"), + ("air_handler_2", "supervisor_run_command"), + ], + ) + + # Setpoint values should be native versions of the env's default values: + setpoint_values = [sp.setpoint_value for sp in ctx.setpoints] + with self.subTest(name="setpoint_values"): + self.assertSequenceAlmostEqual( + setpoint_values, [290.0, 0, 310.0, 0, 290.0, 0] + ) + + def test_get_action_context(self): + ctx = self.agent.get_action_context() + self.assertIsInstance(ctx, action_context.ActionContext) + self.assertEqual(ctx, self.agent.get_default_action_context()) + + +if __name__ == "__main__": + absltest.main() diff --git a/smart_control/llm/agents/llm_agent.py b/smart_control/llm/agents/llm_agent.py new file mode 100644 index 00000000..53cfe4f4 --- /dev/null +++ b/smart_control/llm/agents/llm_agent.py @@ -0,0 +1,389 @@ +"""LLM agent. + +This agent uses a large language model (LLM) to determine its actions. The agent +can use any LLM service that implements the `BaseLLMService` interface. + +First, the agent gets current building conditions from the environment, and +passes this information to a promptmaker, which is responsible for dynamically +constructing a prompt. + +The agent then passes the prompt to the LLM using the configured LLM service. +The agent then validates the LLM's response to ensure it is JSON-formatted and +adheres to the specified "action" output schema. If the action is valid, the +agent stores this information for future reference, and returns the action to +the control loop. + +If the LLM isn't able to produce a valid response, the agent keeps a record of +the error(s), and tries again (with exponential backoff), until it receives a +valid response or reaches the maximum number of tries. + +If the agent doesn't receive a valid action after trying the maximum number of +times, it logs a record of this max retries exceeded error, and gracefully uses +a fallback action: + + + If a previous valid action is available, the agent uses a modified version + of its most recent action, except it uses the shortest possible validity + interval, to give the agent a chance to get a new action at the next + available opportunity. + + If a previous action isn't available, the agent falls back to using the + environment's normally scheduled default action. + +Since this agent inherits from the Schedule Policy Agent to determine the +normally scheduled action, it should be used in conjunction with a hybrid action +environment. +""" + +import re +from typing import Any, override + +from absl import logging +import backoff +from smart_buildings.smart_control.environment import hybrid_action_environment +from smart_buildings.smart_control.llm.agents import base_agent +from smart_buildings.smart_control.llm.agents import schedule_agent +from smart_buildings.smart_control.llm.prompts import promptmaker as pm +from smart_buildings.smart_control.llm.schema import action_context +from smart_buildings.smart_control.llm.services import llm_service as llm +from smart_buildings.smart_control.llm.utils import schedule_tool as schedule_lib +from smart_buildings.smart_control.proto import smart_control_building_pb2 as building_pb2 +from smart_buildings.smart_control.proto import smart_control_reward_pb2 as reward_pb2 +from smart_buildings.smart_control.utils import serialization +from smart_buildings.smart_control.utils import temperature_conversion as temp + +_MARKDOWN_CODE_BLOCK_RE = re.compile(r'```(?:json)?\s*(.*?)\s*```', re.DOTALL) + + +def parse_response_text(txt: str | None) -> str: + """Parses and cleans the raw text response from the LLM. + + The response text is expected to be a JSON-formatted string. In practice + we often see the response wrapped in a JSON-formatted markdown code block, + even when we instruct the LLM to not do that. Perhaps it sees the markdown + formatting in the prompt and thinks it should use markdown formatting in the + response as well. So this method will try to strip the markdown code block + formatting to ensure the text is valid JSON. + + Args: + txt: The raw response text returned by the LLM. + + Returns: + The LLM's textual response as a valid JSON-formatted string. + + Raises: + ValueError: If the response text is not a string. + """ + # FYI: When using the Gemini API, sometimes the response text is None. + # For example, in the case of a max tokens error. + if not isinstance(txt, str): + raise ValueError('Expecting a string response') + + # If the response is wrapped in a markdown code block, extract it: + match = _MARKDOWN_CODE_BLOCK_RE.search(txt) + if match: + return match.group(1).strip() + + # Otherwise, fallback to stripping standard markdown fences and whitespace: + return txt.replace('```json', '').replace('```', '').strip() + + +class MaxRetriesExceededError(Exception): + """Maximum number of retries met or exceeded.""" + + pass + + +class LLMAgent(schedule_agent.SchedulePolicyAgent): + """LLM agent. + + Attributes: + llm_service: The LLM service to use for generating responses. + output_schema_class: The Pydantic model class used to validate and parse the + LLM's JSON response. + temp_display_unit: The temperature unit used for displaying temperatures in + the LLM's response justifications. + promptmaker: The promptmaker instance used to construct the LLM prompt. + max_tries: The maximum number of times to attempt calling the LLM if parsing + or validation errors occur. After this limit is reached, the agent will + fallback to using the scheduled action context. + """ + + def __init__( + self, + *, + env: hybrid_action_environment.HybridActionEnvironment, + llm_service: llm.BaseLLMService, + promptmaker: pm.Promptmaker | None = None, + promptmaker_class: type[pm.Promptmaker] | None = None, + max_tries: int = 5, + clip: bool = True, + override_discrete_defaults: bool = True, + schedule_tool: schedule_lib.BuildingScheduleTool | None = None, + ): + """Initializes the instance. + + Pass either a promptmaker instance or a promptmaker class. The promptmaker + class will only be used if a promptmaker instance is not provided. The + promptmaker class is a convenience argument because we usually want to use + the same promptmaker arguments, and will just change the promptmaker class + to represent different buildings or custom validity intervals. + + Args: + env: The environment in which the agent will operate. + llm_service: The LLM service to use for generating responses. + promptmaker: The promptmaker instance used to construct the LLM prompt. + promptmaker_class: The promptmaker class used to construct the LLM prompt. + If a promptmaker instance is provided, this argument will be ignored, + otherwise a promptmaker instance will be created using this class, using + reasonable default arguments. + max_tries: The maximum number of times to attempt calling the LLM if + parsing or validation errors occur. + clip: Whether to clip the generated setpoints to be within the valid + ranges defined by the environment. + override_discrete_defaults: Whether to override discrete defaults when + getting the scheduled action context. + schedule_tool: Optionally provide a BuildingScheduleTool instance. + Otherwise, a schedule tool will be constructed using the agent's + environment and default schedule tool arguments. + """ + super().__init__( + env=env, + schedule_tool=schedule_tool, + clip=clip, + override_discrete_defaults=override_discrete_defaults, + ) + + self.llm_service = llm_service + self.promptmaker = self._setup_promptmaker(promptmaker, promptmaker_class) + self.output_schema_class = self.promptmaker.output_schema_class + self.temp_display_unit = self.promptmaker.temp_display_unit + self.max_tries = max_tries + + self._last_attempt_response_text: str | None = None + self._last_valid_llm_action: action_context.ActionContext | None = None + + # Wrap `_attempt_get_action_context` to retry with exponential backoff. + self._retry_attempt_get_action_context_with_backoff = backoff.on_exception( + wait_gen=backoff.expo, + exception=Exception, + max_tries=self.max_tries, + jitter=backoff.full_jitter, + on_backoff=self._on_backoff, + on_giveup=self._on_giveup, + )( + self._attempt_get_action_context # the method being wrapped + ) + + def _setup_promptmaker( + self, + promptmaker: pm.Promptmaker | None = None, + promptmaker_class: type[pm.Promptmaker] | None = None, + ) -> pm.Promptmaker: + """Sets up the promptmaker instance.""" + if (promptmaker is None and promptmaker_class is None) or ( + promptmaker is not None and promptmaker_class is not None + ): + raise ValueError( + 'Either a promptmaker instance or class must be provided, not both.' + ) + + return promptmaker or promptmaker_class( + env=self.env, + observation_response=None, + reward_info=None, + lazy_init_protos=True, + output_schema_class=action_context.HybridActionContext, + temp_display_unit=temp.TempUnit.FAHRENHEIT, + include_weights=True, + ) + + @override + @property + def json_metadata(self) -> serialization.SerializableData: + return super().json_metadata | { + 'llm_service': self.llm_service.json_metadata, + 'promptmaker': self.promptmaker.json_metadata, + 'output_schema': {'type': self.output_schema_class.__name__}, + 'temp_display_unit': self.temp_display_unit.value, + 'max_tries': self.max_tries, + } + + # PROMPT + + def make_prompt( + self, + observation_response: building_pb2.ObservationResponse, + reward_info: reward_pb2.RewardInfo, + ) -> str: + """Creates a prompt, using the provided promptmaker class. + + Args: + observation_response: The observation response from the environment. + reward_info: The reward info from the environment. + + Returns: + The prompt to be sent to the LLM. + """ + self.promptmaker.set_protos( + observation_response=observation_response, reward_info=reward_info + ) + return self.promptmaker.prompt + + # RESPONSE VALIDATION + + def validate_action_context(self, txt: str) -> action_context.ActionContext: + """Ensures the response text is in the expected JSON format. + + Args: + txt: The response text to validate. + + Raises: + pydantic.ValidationError: If the response text is not valid JSON. + + Returns: + The validated action context object. + """ + if issubclass(self.output_schema_class, action_context.ActionContext): + return self.output_schema_class.from_json( + txt=txt, env=self.env, clip=self._clip + ) + + action = self.output_schema_class.model_validate_json(txt) + return self.action_context_class( + env=self.env, clip=self._clip, **action.model_dump() + ) + + # ACTION + + def _attempt_get_action_context( + self, + prompt: str, + ) -> action_context.ActionContext: + """Attempts to get a valid action from the LLM. + + Clears and resets the last response text that has been received from the + LLM. + + FYI: When using the Gemini API, sometimes the response text is None. + For example, in the case of a max tokens error. + + Args: + prompt: The prompt to send to the LLM service. + + Raises: + ValueError: If the LLM service returns None. + JSONDecodeError: If the response text is string but not valid JSON. + pydantic.ValidationError: If the JSON doesn't adhere to the output schema. + + Returns: + A validated action context object. + """ + self._last_attempt_response_text = None + response_text = self.llm_service.get_response(prompt) + self._last_attempt_response_text = response_text + + if response_text is None: + raise ValueError('LLM service returned None') + + action = self.validate_action_context(parse_response_text(response_text)) + self._last_valid_llm_action = action + return action + + def _record_backoff_error( + self, + error_details: dict[str, Any], + ) -> None: + """Consolidates logic for recording error details returned by backoff. + + When using the @backoff.on_exception decorator, the details dictionary + passed to the on_backoff and on_giveup handler functions contains the + following keys: + + - target: The decorated function that is being retried. + - args: The positional arguments passed to the target function. + - kwargs: The keyword arguments passed to the target function. + - tries: The number of attempts made so far. + - elapsed: The time in seconds elapsed since the first attempt. + - exception: The exception instance that was caught and triggered the + backoff or giveup. + + Specifically for the on_backoff handler, the details dictionary also + includes: + + - wait: The calculated number of seconds to wait before the next retry + attempt. + + Args: + error_details: The error details returned by backoff. + """ + exception = error_details['exception'] + nested_errors = exception.errors() if hasattr(exception, 'errors') else None + + self.errors.append( + base_agent.AgentErrorRecord( + error_type=exception.__class__.__name__, + error_message=repr(exception), + details=nested_errors, + metadata={ + 'tries': error_details.get('tries'), + 'elapsed': error_details.get('elapsed'), + 'wait': error_details.get('wait'), + 'response_text': self._last_attempt_response_text, + }, + ) + ) + + def _on_backoff(self, details: dict[str, Any]) -> None: + """Records an error that occurred during a backoff retry.""" + logging.debug('ON BACKOFF: %r', details) + self._record_backoff_error(details) + + def _on_giveup(self, details: dict[str, Any]) -> None: + """Records final error and exhaustion of retries.""" + + # Record the final specific error that caused the giveup. + logging.debug('ON GIVEUP: %r', details) + self._record_backoff_error(details) + + # Record that max retries were exceeded. + exhaustion_error = base_agent.AgentErrorRecord( + error_type=MaxRetriesExceededError.__name__, + error_message=f'Max tries ({self.max_tries}) exceeded.', + metadata={}, + ) + self.errors.append(exhaustion_error) + + def get_action_context( + self, + observation_response: building_pb2.ObservationResponse | None = None, + reward_info: reward_pb2.RewardInfo | None = None, + ) -> action_context.ActionContext: + """Returns the action context to be used within the agent control loop. + + Args: + observation_response: The observation response from the environment. + reward_info: The reward info from the environment. + + Returns: + The action context to be used within the agent control loop. + """ + prompt = self.make_prompt(observation_response, reward_info) + try: + return self._retry_attempt_get_action_context_with_backoff(prompt) + except Exception: # pylint: disable=broad-except + # All retry attempts failed, and on_giveup has recorded the errors. + if self._last_valid_llm_action is not None: + logging.exception( + 'LLM MAX TRIES EXCEEDED. FALLING BACK TO PREVIOUS LLM ACTION...', + ) + return self._last_valid_llm_action.model_copy( + update={ + 'validity_interval': self.env.time_step_mins, + 'justification': 'Previous LLM action (max retries exceeded)', + } + ) + + logging.exception( + 'LLM MAX TRIES EXCEEDED. NO PREVIOUS LLM ACTION AVAILABLE. FALLING' + ' BACK TO SCHEDULED ACTION...', + ) + return self.get_scheduled_action_context() diff --git a/smart_control/llm/agents/llm_agent_test.py b/smart_control/llm/agents/llm_agent_test.py new file mode 100644 index 00000000..0c0277a7 --- /dev/null +++ b/smart_control/llm/agents/llm_agent_test.py @@ -0,0 +1,562 @@ +import json +import time +from typing import get_args +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +import pydantic +from smart_buildings.smart_control.environment import conftest as env_conftest +from smart_buildings.smart_control.environment import hybrid_action_environment +from smart_buildings.smart_control.llm.agents import llm_agent +from smart_buildings.smart_control.llm.agents import schedule_agent_test +from smart_buildings.smart_control.llm.prompts import promptmaker as pm +from smart_buildings.smart_control.llm.schema import action_context +from smart_buildings.smart_control.llm.schema import conftest as schema_conftest +from smart_buildings.smart_control.llm.schema import output_schema +from smart_buildings.smart_control.llm.services import llm_service +from smart_buildings.smart_control.llm.utils import schedule_tool as schedule_lib +from smart_buildings.smart_control.llm.utils import schedule_tool_test + + +class TextParsingTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='json_fence', input_text='```json\n{"a": 1}\n```'), + dict(testcase_name='plain_fence', input_text='```\n{"a": 1}\n```'), + dict(testcase_name='no_fences', input_text='{"a": 1}'), + dict( + testcase_name='text_around_fences', + input_text='Text\n```json\n{"a": 1}\n```\nMore text', + ), + ) + def test_parse_response_text_variants(self, input_text): + expected_text = json.dumps({'a': 1}) + self.assertEqual(llm_agent.parse_response_text(input_text), expected_text) + + @parameterized.named_parameters( + dict( + testcase_name='valid_json_valid_schema', + input_text=schema_conftest.create_hybrid_action_response(), + ), + dict( + testcase_name='valid_json_empty_setpoints', + input_text=schema_conftest.create_hybrid_action_response( + empty_setpoints=True + ), + ), + dict( + testcase_name='valid_json_missing_setpoint', + input_text=schema_conftest.create_hybrid_action_response( + missing_setpoint=True + ), + ), + dict( + testcase_name='valid_json_missing_field', + input_text=schema_conftest.create_hybrid_action_response( + missing_field=True + ), + ), + ) + def test_parse_response_text_valid_json_invalid_schema(self, input_text): + self.assertEqual(llm_agent.parse_response_text(input_text), input_text) + + def test_parse_response_text_non_string_input(self): + with self.assertRaisesRegex(ValueError, 'Expecting a string response'): + llm_agent.parse_response_text(None) + + +class LLMAgentTest( + schedule_agent_test.ScheduleHybridActionAgentTest, parameterized.TestCase +): + + env: hybrid_action_environment.HybridActionEnvironment + agent: llm_agent.LLMAgent + mock_llm_service: mock.MagicMock + schedule_tool: schedule_lib.BuildingScheduleTool + + def _create_agent( + self, env: hybrid_action_environment.HybridActionEnvironment + ) -> llm_agent.LLMAgent: + self.mock_llm_service = mock.create_autospec( + llm_service.BaseLLMService, instance=True, spec_set=True + ) + self.mock_llm_service.json_metadata = {'type': 'MockLLMService'} + self.schedule_tool = schedule_lib.BuildingScheduleTool(env=env) + return llm_agent.LLMAgent( + env=env, + llm_service=self.mock_llm_service, + promptmaker_class=pm.Promptmaker, + schedule_tool=self.schedule_tool, + ) + + def test_json_metadata(self): + self.assertEqual( + self.agent.json_metadata, + { + 'type': 'LLMAgent', + 'default_policy': { + 'action_names': [ + 'air_handler_1_supply_air_heating_temperature_setpoint', + 'air_handler_1_supervisor_run_command', + 'boiler_1_supply_water_setpoint', + 'boiler_1_supervisor_run_command', + 'air_handler_2_supply_air_heating_temperature_setpoint', + 'air_handler_2_supervisor_run_command', + ], + 'default_values': [0.0, -1.0, -1.0, -1.0, 0.0, -1.0], + }, + 'override_discrete_defaults': True, + 'schedule_policy': schedule_tool_test.SCHEDULE_METADATA, + 'llm_service': {'type': 'MockLLMService'}, + 'promptmaker': { + 'type': 'Promptmaker', + 'output_schema_class': 'HybridActionContext', + 'include_weights': True, + 'occupancy_mode_min': 10, + 'temp_display_unit': 'Fahrenheit', + 'building_info': { + 'stories': 'two', + 'sqft': 96_000, + 'location': 'Mountain View, California', + 'name': 'SB-1', + }, + }, + 'output_schema': {'type': 'HybridActionContext'}, + 'temp_display_unit': 'Fahrenheit', + 'max_tries': 5, + 'clip': True, + }, + ) + + # PROMPT + + def test_make_prompt(self): + observation_response = self.env.get_observation_response() + reward_info = self.env.get_reward_info() + prompt = self.agent.make_prompt(observation_response, reward_info) + self.assertIsInstance(prompt, str) + self.assertNotEmpty(prompt) + + # RESPONSE VALIDATION + + def test_validate_action_context(self): + valid_response_json = schema_conftest.create_hybrid_action_response() + ctx = self.agent.validate_action_context(valid_response_json) + self.assertIsInstance(ctx, action_context.ActionContext) + + def test_validate_action_context_invalid_json(self): + with self.assertRaisesRegex(json.JSONDecodeError, 'Expecting value'): + self.agent.validate_action_context('oops, invalid json') + + def test_validate_action_context_invalid_schema(self): + with self.assertRaisesRegex( + pydantic.ValidationError, r'validity_interval\n\s+Field required' + ): + self.agent.validate_action_context('{"valid": "json"}') + + def test_validate_action_context_missing_setpoint(self): + invalid_schema_json = schema_conftest.create_hybrid_action_response( + missing_setpoint=True + ) + with self.assertRaisesRegex( + pydantic.ValidationError, r'missing from the schema' + ): + self.agent.validate_action_context(invalid_schema_json) + + def test_validate_action_context_guardrails_exceeded(self): + invalid_schema_json = schema_conftest.create_hybrid_action_response( + ahu_1_run_command=1, + ahu_1_supply_air_temp=99999.0, + hws_run_command=1, + hws_supply_water_temp=99999.0, + ) + ctx = self.agent.validate_action_context(invalid_schema_json) + self.assertEqual( + ctx.guardrails_exceeded, + [ + action_context.GuardrailsExceededRecord( + device_id='air_handler_1', + setpoint_name='supply_air_heating_temperature_setpoint', + requested_value=99999.0, + setpoint_range=(285.0, 295.0), + clipped_value=295.0, + ), + action_context.GuardrailsExceededRecord( + device_id='boiler_1', + setpoint_name='supply_water_setpoint', + requested_value=99999.0, + setpoint_range=(310.0, 350.0), + clipped_value=350.0, + ), + ], + ) + + # GET ACTION + + def test_get_action_context_success(self): + valid_response_json = schema_conftest.create_hybrid_action_response() + self.mock_llm_service.get_response.return_value = valid_response_json + + ctx = self.agent.get_action_context() + self.assertIsInstance(ctx, action_context.ActionContext) + + with self.subTest(name='calls_the_llm'): + self.mock_llm_service.get_response.assert_called_once() + + with self.subTest(name='no_errors'): + self.assertEmpty(self.agent.errors) + + def test_get_action_context_fenced_response(self): + json_data = schema_conftest.create_hybrid_action_response() + fenced_response = ( + f'Here is the JSON you requested:\n```json\n{json_data}\n```\nI hope' + ' that helps!' + ) + self.mock_llm_service.get_response.return_value = fenced_response + + ctx = self.agent.get_action_context() + self.assertIsInstance(ctx, action_context.ActionContext) + + with self.subTest(name='calls_the_llm'): + self.mock_llm_service.get_response.assert_called_once() + + with self.subTest(name='no_errors'): + self.assertEmpty(self.agent.errors) + + +class AlternativeSchemaTestBase(absltest.TestCase): + + def setUp(self): + super().setUp() + self.env = env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT, + default_actions=env_conftest.DEFAULT_HYBRID_ACTIONS, + ) + self.mock_llm_service = mock.create_autospec( + llm_service.BaseLLMService, instance=True, spec_set=True + ) + self.mock_llm_service.json_metadata = {'type': 'MockLLMService'} + valid_json = schema_conftest.create_hybrid_action_response() + self.mock_llm_service.get_response.return_value = valid_json + + +class LLMAgentSetpointsActionSchemaTest(AlternativeSchemaTestBase): + + def setUp(self): + super().setUp() + self.agent = llm_agent.LLMAgent( + env=self.env, + llm_service=self.mock_llm_service, + promptmaker_class=pm.Promptmaker, + ) + + def test_get_action_context(self): + ctx = self.agent.get_action_context() + self.assertIsInstance(ctx, action_context.ActionContext) + + +class LLMAgentCustomValidityIntervalTest(AlternativeSchemaTestBase): + + def setUp(self): + self.custom_intervals = [15, 30, 60, 90] + super().setUp() + self.output_schema_class = action_context.create_action_context_model( + custom_intervals=self.custom_intervals + ) + self.promptmaker = pm.Promptmaker( + env=self.env, + output_schema_class=self.output_schema_class, + lazy_init_protos=True, + ) + self.agent = llm_agent.LLMAgent( + env=self.env, + llm_service=self.mock_llm_service, + promptmaker=self.promptmaker, + ) + + def test_custom_validity_interval(self): + self.assertEqual( + get_args( + self.agent.output_schema_class.__annotations__['validity_interval'] + ), + tuple(self.custom_intervals), + ) + + with self.subTest(name='in_prompt'): + prompt = self.agent.make_prompt( + self.env.get_observation_response(), self.env.get_reward_info() + ) + self.assertIn(str(self.custom_intervals), prompt) + + def test_get_action_context(self): + valid_json = schema_conftest.create_hybrid_action_response() + self.mock_llm_service.get_response.return_value = valid_json + + ctx = self.agent.get_action_context() + self.assertIsInstance(ctx, action_context.ActionContext) + + +class LLMAgentPromptmakerValidationTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.mock_llm_service = mock.create_autospec( + llm_service.BaseLLMService, instance=True, spec_set=True + ) + self.env = env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT, + default_actions=env_conftest.DEFAULT_HYBRID_ACTIONS, + ) + self.promptmaker_instance = pm.Promptmaker( + env=self.env, lazy_init_protos=True + ) + + def test_init_with_promptmaker_instance_succeeds(self): + agent = llm_agent.LLMAgent( + env=self.env, + llm_service=self.mock_llm_service, + promptmaker=self.promptmaker_instance, + ) + self.assertIsInstance(agent.promptmaker, pm.Promptmaker) + self.assertEqual(agent.promptmaker, self.promptmaker_instance) + + def test_init_with_promptmaker_class_succeeds(self): + agent = llm_agent.LLMAgent( + env=self.env, + llm_service=self.mock_llm_service, + promptmaker_class=pm.Promptmaker, + ) + self.assertIsInstance(agent.promptmaker, pm.Promptmaker) + + with self.subTest(name='uses_default_arguments'): + self.assertEqual( + agent.promptmaker.output_schema_class, + action_context.HybridActionContext, + ) + self.assertTrue(agent.promptmaker.lazy_init_protos) + self.assertTrue(agent.promptmaker.include_weights) + + def test_init_raises_error_with_neither_promptmaker_nor_class(self): + with self.assertRaisesRegex( + ValueError, + 'Either a promptmaker instance or class must be provided, not both.', + ): + llm_agent.LLMAgent( + env=self.env, + llm_service=self.mock_llm_service, + promptmaker=None, + promptmaker_class=None, + ) + + def test_init_raises_error_with_both_promptmaker_and_class(self): + with self.assertRaisesRegex( + ValueError, + 'Either a promptmaker instance or class must be provided, not both.', + ): + llm_agent.LLMAgent( + env=self.env, + llm_service=self.mock_llm_service, + promptmaker=self.promptmaker_instance, + promptmaker_class=pm.Promptmaker, + ) + + +class LLMAgentRetryTest(LLMAgentTest): + """Tests for retry and backoff behavior. + + NOTE: We need to mock time.sleep because of the retry logic used in LLMAgent, + which uses the backoff library to handle retries when the LLM fails to return + a valid response. By default, backoff attempts to wait between retries by + calling time.sleep. If we do not mock time.sleep, the unit tests will actually + pause and wait, but instead we are mocking it to return immediately. + """ + + @mock.patch.object(time, 'sleep', return_value=None) + def test_retry_succeeds_after_failures(self, mock_sleep): + valid_response_json = schema_conftest.create_hybrid_action_response() + self.mock_llm_service.get_response.side_effect = [ + RuntimeError('Service Fail 1'), # on_backoff records this + 'oops invalid json', # on_backoff records this + valid_response_json, + ] + agent = llm_agent.LLMAgent( + env=self.env, + llm_service=self.mock_llm_service, + promptmaker_class=pm.Promptmaker, + max_tries=3, + ) + + ctx = agent.get_action_context() + + # The agent should succeed on the 3rd attempt. + self.assertIsInstance(ctx, action_context.ActionContext) + self.assertEqual(self.mock_llm_service.get_response.call_count, 3) + + # 2 errors should be recorded by on_backoff. + self.assertLen(agent.errors, 2) + self.assertEqual(agent.errors[0].error_type, 'RuntimeError') + self.assertEqual(agent.errors[0].metadata['tries'], 1) + self.assertEqual(agent.errors[1].error_type, 'JSONDecodeError') + self.assertEqual(agent.errors[1].metadata['tries'], 2) + + @mock.patch.object(time, 'sleep', return_value=None) + def test_exceeds_max_retries_and_falls_back(self, mock_sleep): + self.mock_llm_service.get_response.side_effect = RuntimeError( + 'Always failing' + ) + agent = llm_agent.LLMAgent( + env=self.env, + llm_service=self.mock_llm_service, + promptmaker_class=pm.Promptmaker, + max_tries=2, + ) + # Mock scheduled action to confirm fallback. + scheduled_ctx = mock.MagicMock() + with mock.patch.object( + agent, + 'get_scheduled_action_context', + return_value=scheduled_ctx, + autospec=True, + ) as mock_get_scheduled: + ctx = agent.get_action_context() + + # Check that fallback occurred. + self.assertEqual(ctx, scheduled_ctx) + mock_get_scheduled.assert_called_once() + + # on_backoff is called for try 1, on_giveup for try 2. + # Total calls: 2. + # Total errors: + # 1st recorded by _on_backoff + # 2nd recorded by _on_giveup + # MaxRetriesExceededError recorded by _on_giveup + self.assertEqual(self.mock_llm_service.get_response.call_count, 2) + self.assertLen(agent.errors, 3) + + # Error from on_backoff + self.assertEqual(agent.errors[0].error_type, 'RuntimeError') + self.assertEqual(agent.errors[0].metadata['tries'], 1) + self.assertIsNotNone(agent.errors[0].metadata['wait']) + + # Error from on_giveup + self.assertEqual(agent.errors[1].error_type, 'RuntimeError') + self.assertEqual(agent.errors[1].metadata['tries'], 2) + self.assertIsNone(agent.errors[1].metadata['wait']) + + # Exhaustion error from on_giveup + self.assertEqual(agent.errors[2].error_type, 'MaxRetriesExceededError') + + @mock.patch.object(time, 'sleep', return_value=None) + def test_pydantic_error_details_on_giveup(self, mock_sleep): + invalid_schema_json = json.dumps({'validity_interval': 15, 'setpoints': []}) + self.mock_llm_service.get_response.return_value = invalid_schema_json + agent = llm_agent.LLMAgent( + env=self.env, + llm_service=self.mock_llm_service, + promptmaker_class=pm.Promptmaker, + max_tries=1, + ) + + agent.get_action_context() + + # on_giveup is called for try 1 because max_tries=1. + self.assertLen(agent.errors, 2) + err = agent.errors[0] + self.assertEqual(err.error_type, 'ValidationError') + self.assertEqual(err.metadata['tries'], 1) + + # Check that pydantic error details were recorded. + self.assertIsInstance(err.details, list) + self.assertNotEmpty(err.details) + self.assertEqual(err.details[0]['type'], 'missing') + self.assertEqual(err.details[0]['loc'], ('timestamp',)) + + @mock.patch.object(llm_agent.logging, 'exception') + @mock.patch.object(time, 'sleep', return_value=None) + def test_exceeds_max_retries_falls_back_to_previous_action( + self, mock_sleep, mock_exception + ): + valid_response_json = schema_conftest.create_hybrid_action_response( + validity_interval=15 + ) + self.mock_llm_service.get_response.side_effect = [ + valid_response_json, # First call succeeds + RuntimeError('Always failing'), # Second call fails + RuntimeError('Always failing'), # Third call fails (exceeds max_tries) + ] + agent = llm_agent.LLMAgent( + env=self.env, + llm_service=self.mock_llm_service, + promptmaker_class=pm.Promptmaker, + max_tries=2, + ) + + # 1. First call to get_action_context succeeds. + ctx1 = agent.get_action_context() + self.assertEqual(ctx1.validity_interval, 15) + + # 2. Second call to get_action_context fails all retries, should fallback to + # the previous successful action with the environment's time step interval. + ctx2 = agent.get_action_context() + mock_exception.assert_called_once_with( + 'LLM MAX TRIES EXCEEDED. FALLING BACK TO PREVIOUS LLM ACTION...' + ) + + self.assertEqual(ctx2.validity_interval, agent.env.time_step_mins) + self.assertEqual(ctx2.setpoints, ctx1.setpoints) + self.assertEqual( + ctx2.justification, 'Previous LLM action (max retries exceeded)' + ) + self.assertEqual(ctx2.timestamp, ctx1.timestamp) + + @mock.patch.object(llm_agent.logging, 'exception') + @mock.patch.object( + llm_agent.LLMAgent, 'get_scheduled_action_context', autospec=True + ) + @mock.patch.object(time, 'sleep', return_value=None) + def test_exceeds_max_retries_no_previous_action_falls_back_to_schedule( + self, mock_sleep, mock_get_scheduled, mock_exception + ): + self.mock_llm_service.get_response.side_effect = RuntimeError('OOPS') + + agent = llm_agent.LLMAgent( + env=self.env, + llm_service=self.mock_llm_service, + promptmaker_class=pm.Promptmaker, + max_tries=2, + ) + + scheduled_ctx = mock.MagicMock() + mock_get_scheduled.return_value = scheduled_ctx + + ctx = agent.get_action_context() + self.assertEqual(ctx, scheduled_ctx) + mock_get_scheduled.assert_called_once() + mock_exception.assert_called_once_with( + 'LLM MAX TRIES EXCEEDED. NO PREVIOUS LLM ACTION AVAILABLE. FALLING' + ' BACK TO SCHEDULED ACTION...' + ) + + +class LLMAgentNonActionContextSchemaTest(AlternativeSchemaTestBase): + + def setUp(self): + super().setUp() + self.agent = llm_agent.LLMAgent( + env=self.env, + llm_service=self.mock_llm_service, + promptmaker_class=pm.Promptmaker, + ) + + def test_validate_action_context_non_subclass(self): + self.agent.output_schema_class = output_schema.SetpointsAction + valid_response_json = schema_conftest.create_hybrid_action_response() + + ctx = self.agent.validate_action_context(valid_response_json) + + self.assertIsInstance(ctx, action_context.ActionContext) + + +if __name__ == '__main__': + absltest.main() diff --git a/smart_control/llm/agents/schedule_agent.py b/smart_control/llm/agents/schedule_agent.py new file mode 100644 index 00000000..1f443410 --- /dev/null +++ b/smart_control/llm/agents/schedule_agent.py @@ -0,0 +1,193 @@ +"""Schedule policy agent. + +This agent determines its actions based on the building's operational +schedule. Based on the current date and time, if the building is operational, +the agent will turn on all devices and use the environment's default setpoint +values. Otherwise, when the building is not operational, the agent will turn off +all devices. + +This agent requires a hybrid action environment, because it needs a mechanism +for turning devices on and off. +""" + +from typing import Final + +import numpy as np +from smart_buildings.smart_control.environment import hybrid_action_environment +from smart_buildings.smart_control.llm.agents import default_agent +from smart_buildings.smart_control.llm.schema import action_context +from smart_buildings.smart_control.llm.schema import output_schema +from smart_buildings.smart_control.llm.utils import schedule_tool as schedule_lib +from smart_buildings.smart_control.proto import smart_control_building_pb2 as building_pb2 +from smart_buildings.smart_control.proto import smart_control_reward_pb2 as reward_pb2 +from smart_buildings.smart_control.utils import serialization + +NATIVE_ON: Final[float] = 1.0 +NATIVE_OFF: Final[float] = 0.0 + + +class SchedulePolicyAgent(default_agent.DefaultPolicyAgent): + """An agent that determines its actions based on the building's schedule. + + Based on the current date and time, if the building is operational, the agent + will turn on all devices and use the environment's default setpoint values. + Otherwise, when the building is not operational, the agent will turn off all + devices. + + Because it is possible (but not common) for an environment's default values to + specify a device should be off, if you want to preserve that behavior and + prevent those devices from being turned on during operational hours, set the + `override_discrete_defaults` to `False`, and the agent will respect those + default values. + + This agent is to be used in conjunction with a hybrid action environment, so + it has a mechanism for turning devices ON or OFF. + + Attributes: + schedule_tool: The BuildingScheduleTool instance used to determine the + building's operational schedule. + override_discrete_defaults: Whether to override the environment's default + values for discrete actions (e.g., turning devices ON/OFF) during + operational hours. If True, discrete devices will be turned ON during + operational hours, even if the default value is OFF. If False, the + default discrete values are respected. + """ + + def __init__( + self, + *, + env: hybrid_action_environment.HybridActionEnvironment, + clip: bool = True, + override_discrete_defaults: bool = True, + schedule_tool: schedule_lib.BuildingScheduleTool | None = None, + ): + """Initializes the instance. + + Args: + env: The hybrid action environment the agent will interact with. + clip: Whether to clip setpoint values to the bounds of the valid range. If + `False`, raises `GuardrailsExceededError` if setpoint values are out of + range. Otherwise, clips the setpoint values to the valid range, and logs + a record of the error. Defaults to `True`. + override_discrete_defaults: Whether to override the default policy values + for discrete actions during operational hours. By default, the agent + will turn on all devices during operational hours, potentially + overriding any default values that specify a device should be off. If + you have default values that specify a device should be off during + operational hours, set this option to `False` and the agent will respect + those defaults. + schedule_tool: Optionally provide a BuildingScheduleTool instance. + Otherwise, a schedule tool will be constructed using the agent's + environment and default schedule tool arguments. + """ + super().__init__(env=env, clip=clip) + self.schedule_tool = schedule_tool or schedule_lib.BuildingScheduleTool( + env=env, + ) + self.override_discrete_defaults = override_discrete_defaults + + @property + def json_metadata(self) -> serialization.SerializableData: + """Info to write into a JSON file. Needs to be serializable.""" + return super().json_metadata | { + "override_discrete_defaults": self.override_discrete_defaults, + "schedule_policy": self.schedule_tool.json_metadata, + } + + @property + def building_operational_mode(self) -> schedule_lib.BuildingOperationalMode: + """The building's operational mode.""" + return self.schedule_tool.building_operational_mode + + @property + def building_is_operational(self) -> bool: + """Whether the building is operational.""" + return self.schedule_tool.building_is_operational + + @property + def scheduled_justification(self) -> str: + return f"Scheduled action ({self.building_operational_mode.value.upper()})" + + @property + def scheduled_setpoint_justification(self) -> str: + return f"Scheduled value ({self.building_operational_mode.value.upper()})" + + def get_scheduled_native_value( + self, setpoint_name: str, native_value: float + ) -> float: + """Determines the scheduled native value for a given setpoint. + + This method will flip the value of discrete actions to ON or OFF, depending + on whether the building is operational or not. + + Because it is possible (but not common) for an environment's default values + to specify a device should be off, if you want to preserve that behavior and + prevent those devices from being turned on during operational hours, set the + `override_discrete_defaults` to `False`, and the agent will respect those + default values. + + Args: + setpoint_name: The name of a given setpoint. + native_value: The native action value for the given setpoint. + + Returns: + The scheduled native action value for the setpoint. + """ + if not hybrid_action_environment.is_discrete_action(setpoint_name): + return native_value + + if self.building_is_operational: + return NATIVE_ON if self.override_discrete_defaults else native_value + return NATIVE_OFF + + def get_scheduled_action_context(self) -> action_context.ActionContext: + """Gets an action context based on the building's operational schedule. + + This action context uses the environment's default policy values as a base, + but ensures devices are turned off during non-operational hours, and on + during operational hours. + + Returns: + An action context representing the scheduled action. + """ + setpoints = [] + for action_name, normalized_value in zip( + self.env.action_names, self.env.default_policy_values, strict=True + ): + device_id, setpoint_name = self.env.id_map.inv[action_name] + + normalizer = self.env.action_normalizers.get(setpoint_name) + if normalizer is None: + raise ValueError(f"No normalizer found for setpoint: {setpoint_name}") + + native_value = normalizer.setpoint_value(np.array(normalized_value)) + scheduled_native_value = self.get_scheduled_native_value( + setpoint_name=setpoint_name, native_value=native_value + ) + + setpoints.append( + output_schema.DeviceSetpoint( + device_id=device_id, + setpoint_name=setpoint_name, + setpoint_value=scheduled_native_value, + justification=self.scheduled_setpoint_justification, + ) + ) + + return self.action_context_class( + env=self.env, + clip=self._clip, + timestamp=str(self.env.current_local_timestamp), + justification=self.scheduled_justification, + validity_interval=self.env.time_step_mins, + setpoints=setpoints, + ) + + def get_action_context( + self, + observation_response: building_pb2.ObservationResponse | None = None, + reward_info: reward_pb2.RewardInfo | None = None, + ) -> action_context.ActionContext: + """Returns the action context for the environment.""" + del observation_response, reward_info # Unused in this implementation. + return self.get_scheduled_action_context() diff --git a/smart_control/llm/agents/schedule_agent_test.py b/smart_control/llm/agents/schedule_agent_test.py new file mode 100644 index 00000000..404a496d --- /dev/null +++ b/smart_control/llm/agents/schedule_agent_test.py @@ -0,0 +1,169 @@ +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +from smart_buildings.smart_control.environment import conftest as env_conftest +from smart_buildings.smart_control.llm.agents import default_agent_test +from smart_buildings.smart_control.llm.agents import schedule_agent +from smart_buildings.smart_control.llm.schema import action_context +from smart_buildings.smart_control.llm.utils import schedule_tool +from smart_buildings.smart_control.llm.utils import schedule_tool_test + +TIME_ZONE = "US/Pacific" + + +class ScheduleHybridActionAgentTest( + default_agent_test.DefaultHybridActionAgentTest +): + + agent: schedule_agent.SchedulePolicyAgent + + def _create_environment( + self, start_timestamp=schedule_tool_test.CURRENT_LOCAL_TIMESTAMP + ): + return super()._create_environment(start_timestamp=start_timestamp) + + def _create_agent(self, env): + return schedule_agent.SchedulePolicyAgent(env=env) + + def test_json_metadata(self): + self.assertEqual( + self.agent.json_metadata, + { + "type": "SchedulePolicyAgent", + "default_policy": { + "action_names": [ + "air_handler_1_supply_air_heating_temperature_setpoint", + "air_handler_1_supervisor_run_command", + "boiler_1_supply_water_setpoint", + "boiler_1_supervisor_run_command", + "air_handler_2_supply_air_heating_temperature_setpoint", + "air_handler_2_supervisor_run_command", + ], + "default_values": [0.0, -1.0, -1.0, -1.0, 0.0, -1.0], + }, + "clip": True, + "override_discrete_defaults": True, + "schedule_policy": schedule_tool_test.SCHEDULE_METADATA, + }, + ) + + def test_building_is_operational(self): + self.assertTrue(self.agent.building_is_operational) + + def test_building_operational_mode(self): + self.assertEqual( + self.agent.building_operational_mode, + schedule_tool.BuildingOperationalMode.ON, + ) + + def test_justifications(self): + self.assertEqual( + self.agent.scheduled_justification, + "Scheduled action (ON)", + ) + self.assertEqual( + self.agent.scheduled_setpoint_justification, + "Scheduled value (ON)", + ) + + def test_scheduled_action_context(self): + ctx = self.agent.get_scheduled_action_context() + self.assertIsInstance(ctx, action_context.HybridActionContext) + + with self.subTest(name="timestamp"): + self.assertEqual(ctx.timestamp, str(self.env.current_local_timestamp)) + + with self.subTest(name="justification"): + self.assertEqual(ctx.justification, self.agent.scheduled_justification) + + with self.subTest(name="validity_interval"): + self.assertEqual(ctx.validity_interval, self.env.time_step_mins) + + with self.subTest(name="setpoints"): + self.assertLen(ctx.setpoints, len(self.env.action_names)) + + # Setpoint and device names should match the env's action names: + names = [(sp.device_id, sp.setpoint_name) for sp in ctx.setpoints] + self.assertEqual( + names, + [ + ("air_handler_1", "supply_air_heating_temperature_setpoint"), + ("air_handler_1", "supervisor_run_command"), + ("boiler_1", "supply_water_setpoint"), + ("boiler_1", "supervisor_run_command"), + ("air_handler_2", "supply_air_heating_temperature_setpoint"), + ("air_handler_2", "supervisor_run_command"), + ], + ) + + # Setpoint values should be native versions of the env's default values + # overridden by schedule policy. + setpoint_values = [sp.setpoint_value for sp in ctx.setpoints] + self.assertSequenceAlmostEqual( + setpoint_values, [290.0, 1.0, 310.0, 1.0, 290.0, 1.0] + ) + + def test_get_action_context(self): + ctx = self.agent.get_action_context() + self.assertIsInstance(ctx, action_context.HybridActionContext) + self.assertEqual(ctx, self.agent.get_scheduled_action_context()) + + +class ScheduleScenariosTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.env = env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT, + default_actions=env_conftest.DEFAULT_HYBRID_ACTIONS, + start_timestamp=schedule_tool_test.CURRENT_LOCAL_TIMESTAMP, + ) + + @parameterized.parameters( + # Parameters: + # setpoint_name, native_val, is_operational, will_override, expected_val + # ... + # Continuous setpoint, building is operational: + ("supply_air_heating_temperature_setpoint", 290.0, True, True, 290.0), + ("supply_air_heating_temperature_setpoint", 290.0, True, False, 290.0), + # Continuous setpoint, building is non-operational: + ("supply_air_heating_temperature_setpoint", 290.0, False, True, 290.0), + ("supply_air_heating_temperature_setpoint", 290.0, False, False, 290.0), + # Discrete action, building is non-operational + ("supervisor_run_command", 1.0, False, True, 0.0), + ("supervisor_run_command", 1.0, False, False, 0.0), + # Discrete action, building is operational, override defaults: + ("supervisor_run_command", 1.0, True, True, 1.0), # FLIPPED ON + ("supervisor_run_command", 0.0, True, True, 1.0), # FLIPPED ON + ("supervisor_run_command", -1.0, True, True, 1.0), # FLIPPED ON + # Discrete action, building is operational, do not override defaults: + ("supervisor_run_command", 1.0, True, False, 1.0), + ("supervisor_run_command", 0.0, True, False, 0.0), + ("supervisor_run_command", -1.0, True, False, -1.0), + ) + def test_get_scheduled_native_value( + self, + setpoint_name, + native_value, + building_is_operational, + override_discrete_defaults, + expected_value, + ): + mock_schedule_tool = mock.Mock() + type(mock_schedule_tool).building_is_operational = mock.PropertyMock( + return_value=building_is_operational + ) + agent = schedule_agent.SchedulePolicyAgent( + env=self.env, + schedule_tool=mock_schedule_tool, + override_discrete_defaults=override_discrete_defaults, + ) + self.assertEqual( + agent.get_scheduled_native_value(setpoint_name, native_value), + expected_value, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/smart_control/llm/llm_environment_test.py b/smart_control/llm/llm_environment_test.py new file mode 100644 index 00000000..c40c385f --- /dev/null +++ b/smart_control/llm/llm_environment_test.py @@ -0,0 +1,723 @@ +"""More tests for the environment, to ensure the LLM agent can use it.""" + +import dataclasses + +from absl.testing import absltest +from absl.testing import parameterized +import mock +import pandas as pd +from smart_buildings.smart_control.environment import conftest +from smart_buildings.smart_control.environment import environment +from smart_buildings.smart_control.environment import hybrid_action_environment +from smart_buildings.smart_control.models import base_building +from smart_buildings.smart_control.models import base_reward_function +from smart_buildings.smart_control.proto import smart_control_building_pb2 +from smart_buildings.smart_control.utils import bounded_action_normalizer +from smart_buildings.smart_control.utils import building_image_generator +from smart_buildings.smart_control.utils import controller_writer +from smart_buildings.smart_control.utils import observation_normalizer + + +class LLMEnvironmentTest(parameterized.TestCase): + """Ensures the environment has what it needs for an LLM agent use case.""" + + def setUp(self): + super().setUp() + self.env = conftest.create_environment( + default_actions=conftest.DEFAULT_ACTIONS + ) + + def test_initialization(self): + self.assertIsInstance(self.env, environment.Environment) + + with self.subTest(name="building"): + self.assertIsInstance(self.env.building, base_building.BaseBuilding) + + with self.subTest(name="reward_function"): + self.assertIsInstance( + self.env.reward_function, base_reward_function.BaseRewardFunction + ) + + with self.subTest(name="observation_normalizer"): + self.assertIsInstance( + self.env.observation_normalizer, + observation_normalizer.StandardScoreObservationNormalizer, + ) + + with self.subTest(name="action_normalizers"): + self.assertIsInstance(self.env.action_normalizers, dict) + + bounded_normalizers = [ + isinstance(n, bounded_action_normalizer.BoundedActionNormalizer) + for n in self.env.action_normalizers.values() + ] + self.assertTrue(all(bounded_normalizers)) + + with self.subTest(name="default_actions"): + self.assertEqual( + self.env.default_action_values, conftest.DEFAULT_ACTION_VALUES + ) + + def test_properties(self): + with self.subTest(name="step_count"): + self.assertEqual(self.env.step_count, 0) + + with self.subTest(name="time_step_sec"): + self.assertEqual(self.env.time_step_sec, 300) + + with self.subTest(name="time_step_mins"): + self.assertEqual(self.env.time_step_mins, 5) + + with self.subTest(name="time_zone"): + self.assertEqual(self.env.time_zone, "US/Pacific") + + with self.subTest(name="current_simulation_timestamp"): + ts = self.env.current_simulation_timestamp + self.assertIsNone(ts.tz) + self.assertEqual(ts, pd.Timestamp("2021-06-07 12:00:01")) + + with self.subTest(name="current_local_timestamp"): + ts = self.env.current_local_timestamp + self.assertEqual(ts, pd.Timestamp("2021-06-07 12:00:01", tz="US/Pacific")) + + with self.subTest(name="json_metadata"): + expected_metadata = { + "type": "Environment", + "time_step_sec": 300.0, + "start_timestamp": "2021-06-07 12:00:01", + "end_timestamp": "2021-06-10 12:00:01", + "metrics_output_dir": None, + "action_names": [ + "air_handler_1_supply_air_heating_temperature_setpoint", + "boiler_1_supply_water_setpoint", + "air_handler_2_supply_air_heating_temperature_setpoint", + ], + "default_action_values": [0.0, -1.0, 0.0], + "reward_function": {"type": "SimpleRewardFunction"}, + "building": { + "n_devices": 4, + "n_zones": 2, + "device_ids": [ + "air_handler_1", + "boiler_1", + "air_handler_2", + "vav_1", + ], + "zone_ids": ["zone_1", "zone_2"], + }, + "occupancy": None, + } + self.assertEqual(self.env.json_metadata, expected_metadata) + + def test_json_metadata_with_occupancy(self): + self.env.building.occupancy = mock.MagicMock() + occupancy_metadata = {"type": "MockOccupancyModel"} + self.env.building.occupancy.json_metadata = occupancy_metadata + self.assertEqual(self.env.json_metadata["occupancy"], occupancy_metadata) + + def test_building_devices(self): + df = self.env.building.devices_df + self.assertIsInstance(df, pd.DataFrame) + + expected_records = [ + { + "device_id": "air_handler_1", + "namespace": "", + "code": "", + "zone_id": "zone_1", + "device_type_id": smart_control_building_pb2.DeviceInfo.AHU, + "device_type": "AHU", + }, + { + "device_id": "boiler_1", + "namespace": "", + "code": "", + "zone_id": "zone_1", + "device_type_id": smart_control_building_pb2.DeviceInfo.BLR, + "device_type": "BLR", + }, + { + "device_id": "air_handler_2", + "namespace": "", + "code": "", + "zone_id": "zone_2", + "device_type_id": smart_control_building_pb2.DeviceInfo.AHU, + "device_type": "AHU", + }, + { + "device_id": "vav_1", + "namespace": "", + "code": "", + "zone_id": "zone_2", + "device_type_id": smart_control_building_pb2.DeviceInfo.VAV, + "device_type": "VAV", + }, + ] + self.assertEqual(df.to_dict("records"), expected_records) + + def test_building_zones(self): + df = self.env.building.zones_df + self.assertIsInstance(df, pd.DataFrame) + + expected_records = [ + { + "building_id": "SimpleBuilding", + "zone_id": "zone_1", + "zone_type_id": smart_control_building_pb2.ZoneInfo.UNDEFINED, + "zone_type": "UNDEFINED", + "description": "zone_1", + "area": 0.0, + "floor": 0, + "device_ids": ["air_handler_1", "boiler_1"], + }, + { + "building_id": "SimpleBuilding", + "zone_id": "zone_2", + "zone_type_id": smart_control_building_pb2.ZoneInfo.UNDEFINED, + "zone_type": "UNDEFINED", + "description": "zone_2", + "area": 0.0, + "floor": 0, + "device_ids": ["air_handler_2", "vav_1"], + }, + ] + self.assertEqual(df.to_dict("records"), expected_records) + + def test_action_fields_df(self): + self.assertIsInstance(self.env.action_fields_df, pd.DataFrame) + records = self.env.action_fields_df.to_dict("records") + expected_records = [ + { + "action_name": "air_handler_1_supply_air_heating_temperature_setpoint", # pylint: disable=line-too-long + "device_id": "air_handler_1", + "device_type": "AHU", + "zone_id": "zone_1", + "setpoint_name": "supply_air_heating_temperature_setpoint", + "value_type": "VALUE_CONTINUOUS", + "setpoint_type": "CONTINUOUS", + "units": "Kelvin", + "max_native_value": 295.0, + "max_normalized_value": 1.0, + "min_native_value": 285.0, + "min_normalized_value": -1.0, + }, + { + "action_name": "boiler_1_supply_water_setpoint", + "device_id": "boiler_1", + "device_type": "BLR", + "zone_id": "zone_1", + "setpoint_name": "supply_water_setpoint", + "value_type": "VALUE_CONTINUOUS", + "setpoint_type": "CONTINUOUS", + "units": "Kelvin", + "max_native_value": 350.0, + "max_normalized_value": 1.0, + "min_native_value": 310.0, + "min_normalized_value": -1.0, + }, + { + "action_name": "air_handler_2_supply_air_heating_temperature_setpoint", # pylint: disable=line-too-long + "device_id": "air_handler_2", + "device_type": "AHU", + "zone_id": "zone_2", + "setpoint_name": "supply_air_heating_temperature_setpoint", + "value_type": "VALUE_CONTINUOUS", + "setpoint_type": "CONTINUOUS", + "units": "Kelvin", + "max_native_value": 295.0, + "max_normalized_value": 1.0, + "min_native_value": 285.0, + "min_normalized_value": -1.0, + }, + ] + self.assertCountEqual(records, expected_records) + + @parameterized.named_parameters( + dict( + testcase_name="records_from_normalized_values", + method_name="get_action_records_from_normalized_values", + values=conftest.NORMALIZED_ACTION_VALUES, + expect_df=False, + ), + dict( + testcase_name="df_from_normalized_values", + method_name="get_action_df_from_normalized_values", + values=conftest.NORMALIZED_ACTION_VALUES, + expect_df=True, + ), + dict( + testcase_name="records_from_native_values", + method_name="get_action_records_from_native_values", + values=conftest.NATIVE_ACTION_VALUES, + expect_df=False, + ), + dict( + testcase_name="df_from_native_values", + method_name="get_action_df_from_native_values", + values=conftest.NATIVE_ACTION_VALUES, + expect_df=True, + ), + ) + def test_get_action_records(self, method_name, values, expect_df): + result = getattr(self.env, method_name)(values) + if expect_df: + self.assertCountEqual(result.to_dict("records"), conftest.ACTION_RECORDS) + else: + self.assertCountEqual( + [dataclasses.asdict(r) for r in result], + conftest.ACTION_RECORDS, + ) + + def test_step(self): + self.env.reset() + self.assertEqual(self.env.step_count, 0) + self.env.step([0, 0, 0]) # normalized action values + self.assertEqual(self.env.step_count, 1) + + def test_step_with_defaults(self): + self.env.reset() + self.assertEqual(self.env.step_count, 0) + self.env.step(self.env.default_action_values) + self.assertEqual(self.env.step_count, 1) + + def test_observations(self): + n_device_measurements = 4 # see all "_measurement" in conftest.LAYOUT + n_auxiliary_measurements = 7 + n_observations = n_device_measurements + n_auxiliary_measurements + self.assertEqual( + self.env.observation_spec(), + conftest.create_observation_spec(n_observations), + ) + + def test_actions(self): + self.assertEqual( + self.env.action_spec(), conftest.create_action_spec(n_continuous=3) + ) + self.assertSequenceEqual( + self.env.action_names, + [ + "air_handler_1_supply_air_heating_temperature_setpoint", + "boiler_1_supply_water_setpoint", + "air_handler_2_supply_air_heating_temperature_setpoint", + ], + ) + + +class LLMHybridActionEnvironmentTest(parameterized.TestCase): + """Ensures the environment has what it needs for an LLM agent use case.""" + + def setUp(self): + super().setUp() + self.env = conftest.create_hybrid_action_environment( + layout=conftest.DEMO_LAYOUT, + default_actions=conftest.DEFAULT_HYBRID_ACTIONS, + ) + + def test_initialization(self): + self.assertIsInstance( + self.env, hybrid_action_environment.HybridActionEnvironment + ) + + with self.subTest(name="building"): + self.assertIsInstance(self.env.building, base_building.BaseBuilding) + + with self.subTest(name="reward_function"): + self.assertIsInstance( + self.env.reward_function, base_reward_function.BaseRewardFunction + ) + + with self.subTest(name="observation_normalizer"): + self.assertIsInstance( + self.env.observation_normalizer, + observation_normalizer.StandardScoreObservationNormalizer, + ) + + with self.subTest(name="action_normalizers"): + self.assertIsInstance(self.env.action_normalizers, dict) + + self.assertEqual( + set(type(n) for n in self.env.action_normalizers.values()), + {bounded_action_normalizer.BoundedActionNormalizer}, + ) + + with self.subTest(name="default_actions"): + self.assertEqual( + self.env.default_action_values, conftest.DEFAULT_HYBRID_ACTION_VALUES + ) + self.assertEqual( + self.env.default_hybrid_action, conftest.DEFAULT_HYBRID_ACTION_DICT + ) + + def test_building_devices(self): + df = self.env.building.devices_df + self.assertIsInstance(df, pd.DataFrame) + + expected_records = [ + { + "device_id": "air_handler_1", + "namespace": "", + "code": "", + "zone_id": "zone_1", + "device_type_id": smart_control_building_pb2.DeviceInfo.AHU, + "device_type": "AHU", + }, + { + "device_id": "boiler_1", + "namespace": "", + "code": "", + "zone_id": "zone_1", + "device_type_id": smart_control_building_pb2.DeviceInfo.BLR, + "device_type": "BLR", + }, + { + "device_id": "air_handler_2", + "namespace": "", + "code": "", + "zone_id": "zone_2", + "device_type_id": smart_control_building_pb2.DeviceInfo.AHU, + "device_type": "AHU", + }, + { + "device_id": "outside_air_sensor", + "namespace": "", + "code": "", + "zone_id": "zone_2", + "device_type_id": smart_control_building_pb2.DeviceInfo.UNDEFINED, + "device_type": "UNDEFINED", + }, + ] + self.assertEqual(df.to_dict("records"), expected_records) + + def test_building_zones(self): + df = self.env.building.zones_df + self.assertIsInstance(df, pd.DataFrame) + + expected_records = [ + { + "building_id": "SimpleBuilding", + "zone_id": "zone_1", + "zone_type_id": smart_control_building_pb2.ZoneInfo.UNDEFINED, + "zone_type": "UNDEFINED", + "description": "zone_1", + "area": 0.0, + "floor": 0, + "device_ids": ["air_handler_1", "boiler_1"], + }, + { + "building_id": "SimpleBuilding", + "zone_id": "zone_2", + "zone_type_id": smart_control_building_pb2.ZoneInfo.UNDEFINED, + "zone_type": "UNDEFINED", + "description": "zone_2", + "area": 0.0, + "floor": 0, + "device_ids": ["air_handler_2", "outside_air_sensor"], + }, + ] + self.assertEqual(df.to_dict("records"), expected_records) + + def test_properties(self): + with self.subTest(name="time_zone"): + self.assertEqual(self.env.time_zone, "US/Pacific") + + with self.subTest(name="current_simulation_timestamp"): + self.assertEqual( + self.env.current_simulation_timestamp, + pd.Timestamp("2021-06-07 12:00:01"), + ) + + with self.subTest(name="step_count"): + self.assertEqual(self.env.step_count, 0) + + with self.subTest(name="json_metadata"): + expected_metadata = { + "type": "HybridActionEnvironment", + "time_step_sec": 300.0, + "start_timestamp": "2021-06-07 12:00:01", + "end_timestamp": "2021-06-10 12:00:01", + "metrics_output_dir": None, + "action_names": [ + "air_handler_1_supply_air_heating_temperature_setpoint", + "air_handler_1_supervisor_run_command", + "boiler_1_supply_water_setpoint", + "boiler_1_supervisor_run_command", + "air_handler_2_supply_air_heating_temperature_setpoint", + "air_handler_2_supervisor_run_command", + ], + "default_action_values": [0.0, -1.0, -1.0, -1.0, 0.0, -1.0], + "reward_function": {"type": "SimpleRewardFunction"}, + "building": { + "n_devices": 4, + "n_zones": 2, + "device_ids": [ + "air_handler_1", + "boiler_1", + "air_handler_2", + "outside_air_sensor", + ], + "zone_ids": ["zone_1", "zone_2"], + }, + "occupancy": None, + } + self.assertEqual(self.env.json_metadata, expected_metadata) + + def test_action_fields_df(self): + df = self.env.action_fields_df + self.assertIsInstance(df, pd.DataFrame) + expected_records = [ + { + "action_name": "air_handler_1_supervisor_run_command", + "device_id": "air_handler_1", + "device_type": "AHU", + "zone_id": "zone_1", + "setpoint_name": "supervisor_run_command", + "value_type": "VALUE_CONTINUOUS", + "setpoint_type": "DISCRETE", + "units": "On/Off", + "max_native_value": 1.0, + "max_normalized_value": 1.0, + "min_native_value": 0.0, + "min_normalized_value": -1.0, + }, + { + "action_name": "air_handler_1_supply_air_heating_temperature_setpoint", # pylint: disable=line-too-long + "device_id": "air_handler_1", + "device_type": "AHU", + "zone_id": "zone_1", + "setpoint_name": "supply_air_heating_temperature_setpoint", + "value_type": "VALUE_CONTINUOUS", + "setpoint_type": "CONTINUOUS", + "units": "Kelvin", + "max_native_value": 295.0, + "max_normalized_value": 1.0, + "min_native_value": 285.0, + "min_normalized_value": -1.0, + }, + { + "action_name": "boiler_1_supervisor_run_command", + "device_id": "boiler_1", + "device_type": "BLR", + "zone_id": "zone_1", + "setpoint_name": "supervisor_run_command", + "value_type": "VALUE_CONTINUOUS", + "setpoint_type": "DISCRETE", + "units": "On/Off", + "max_native_value": 1.0, + "max_normalized_value": 1.0, + "min_native_value": 0.0, + "min_normalized_value": -1.0, + }, + { + "action_name": "boiler_1_supply_water_setpoint", + "device_id": "boiler_1", + "device_type": "BLR", + "zone_id": "zone_1", + "setpoint_name": "supply_water_setpoint", + "value_type": "VALUE_CONTINUOUS", + "setpoint_type": "CONTINUOUS", + "units": "Kelvin", + "max_native_value": 350.0, + "max_normalized_value": 1.0, + "min_native_value": 310.0, + "min_normalized_value": -1.0, + }, + { + "action_name": "air_handler_2_supervisor_run_command", + "device_id": "air_handler_2", + "device_type": "AHU", + "zone_id": "zone_2", + "setpoint_name": "supervisor_run_command", + "value_type": "VALUE_CONTINUOUS", + "setpoint_type": "DISCRETE", + "units": "On/Off", + "max_native_value": 1.0, + "max_normalized_value": 1.0, + "min_native_value": 0.0, + "min_normalized_value": -1.0, + }, + { + "action_name": "air_handler_2_supply_air_heating_temperature_setpoint", # pylint: disable=line-too-long + "device_id": "air_handler_2", + "device_type": "AHU", + "zone_id": "zone_2", + "setpoint_name": "supply_air_heating_temperature_setpoint", + "value_type": "VALUE_CONTINUOUS", + "setpoint_type": "CONTINUOUS", + "units": "Kelvin", + "max_native_value": 295.0, + "max_normalized_value": 1.0, + "min_native_value": 285.0, + "min_normalized_value": -1.0, + }, + ] + self.assertCountEqual(df.to_dict("records"), expected_records) + + @parameterized.named_parameters( + dict( + testcase_name="records_from_normalized_values", + method_name="get_action_records_from_normalized_values", + values=conftest.NORMALIZED_HYBRID_ACTION_VALUES, + expect_df=False, + ), + dict( + testcase_name="df_from_normalized_values", + method_name="get_action_df_from_normalized_values", + values=conftest.NORMALIZED_HYBRID_ACTION_VALUES, + expect_df=True, + ), + dict( + testcase_name="records_from_native_values", + method_name="get_action_records_from_native_values", + values=conftest.NATIVE_HYBRID_ACTION_VALUES, + expect_df=False, + ), + dict( + testcase_name="df_from_native_values", + method_name="get_action_df_from_native_values", + values=conftest.NATIVE_HYBRID_ACTION_VALUES, + expect_df=True, + ), + ) + def test_get_action_records(self, method_name, values, expect_df): + result = getattr(self.env, method_name)(values) + if expect_df: + self.assertCountEqual( + result.to_dict("records"), conftest.HYBRID_ACTION_RECORDS + ) + else: + self.assertCountEqual( + [dataclasses.asdict(r) for r in result], + conftest.HYBRID_ACTION_RECORDS, + ) + + def test_step(self): + self.env.reset() + self.assertEqual(self.env.step_count, 0) + self.env.step({ + "discrete_action": [0, 0, 0], + "continuous_action": [-1.0, 0.0, 1.0], + }) + self.assertEqual(self.env.step_count, 1) + + def test_step_with_defaults(self): + self.env.reset() + self.assertEqual(self.env.step_count, 0) + self.env.step(self.env.default_hybrid_action) + self.assertEqual(self.env.step_count, 1) + + def test_convert_to_hybrid(self): + action_values = [-1.0, -1.0, 0.0, 1.0, 1.0, 1.0] + expected_hybrid_action = { + "discrete_action": [0.0, 1.0, 1.0], + "continuous_action": [-1.0, 0.0, 1.0], + } + hybrid_action = self.env.convert_to_hybrid(action_values) + self.assertEqual(hybrid_action, expected_hybrid_action) + + def test_observations(self): + n_device_measurements = 1 # see all "_measurement" in conftest.DEMO_LAYOUT + n_auxiliary_measurements = 7 + n_observations = n_device_measurements + n_auxiliary_measurements + with self.subTest(name="observation_spec"): + self.assertEqual( + self.env.observation_spec(), + conftest.create_observation_spec(n_observations), + ) + + def test_actions(self): + with self.subTest(name="action_spec"): + self.assertEqual( + self.env.action_spec(), + conftest.create_hybrid_action_spec(n_continuous=3, n_discrete=3), + ) + + with self.subTest(name="action_names"): + self.assertSequenceEqual( + self.env.action_names, + [ + "air_handler_1_supply_air_heating_temperature_setpoint", + "air_handler_1_supervisor_run_command", + "boiler_1_supply_water_setpoint", + "boiler_1_supervisor_run_command", + "air_handler_2_supply_air_heating_temperature_setpoint", + "air_handler_2_supervisor_run_command", + ], + ) + + +# +# METRICS WRITER TESTS +# + + +class EnvironmentMetricsWriterTest(parameterized.TestCase): + """Ensures the environment metrics are written.""" + + def setUp(self): + super().setUp() + self.metrics_path = self.create_tempdir().full_path + writer_factory = controller_writer.ProtoWriterFactory() + self.env = conftest.create_environment( + metrics_path=self.metrics_path, writer_factory=writer_factory + ) + + def test_metrics_writer(self): + self.assertIsInstance( + self.env._metrics_writer, controller_writer.ProtoWriter + ) + self.assertStartsWith( + self.env._metrics_writer._output_dir, self.metrics_path + ) + + def test_reset_writes_metrics(self): + # the setup for this test is a little more complex, since the reset() method + # creates a new metrics writer... + # so we are mocking the writer_factory.create method to return a mock writer + writer = mock.create_autospec(controller_writer.ProtoWriter, instance=True) + + with mock.patch.object( + self.env._writer_factory, "create", return_value=writer, autospec=True + ) as mock_create_method: + self.env.reset() + + mock_create_method.assert_called_once() + writer.write_device_infos.assert_called_once_with(self.env.building.devices) + writer.write_zone_infos.assert_called_once_with(self.env.building.zones) + + @parameterized.parameters("get_reward", "get_reward_info") + def test_reward_methods_write_metrics(self, method_name): + self.env._metrics_writer = mock.Mock() + + getattr(self.env, method_name)() + + with self.subTest(name="writes reward_info"): + self.env._metrics_writer.write_reward_info.assert_called_once() + + with self.subTest(name="writes reward_response"): + self.env._metrics_writer.write_reward_response.assert_called_once() + + @parameterized.parameters("get_observation_response", "_get_observation") + def test_observation_methods_write_metrics(self, method_name): + self.env._metrics_writer = mock.Mock() + self.env._building_image_generator = mock.create_autospec( + building_image_generator.BuildingImageGenerator, instance=True + ) + + getattr(self.env, method_name)() + + with self.subTest(name="writes observation_response"): + self.env._metrics_writer.write_observation_response.assert_called_once() + + with self.subTest(name="writes building image if generator is set"): + self.env._metrics_writer.write_building_image.assert_called_once() + + def test_step_writes_metrics(self): + self.env._metrics_writer = mock.Mock() + + self.env.step([0, 0, 0]) + + self.env._metrics_writer.write_action_response.assert_called_once() + + +if __name__ == "__main__": + absltest.main() diff --git a/smart_control/llm/loop/conftest.py b/smart_control/llm/loop/conftest.py new file mode 100644 index 00000000..15cb9bed --- /dev/null +++ b/smart_control/llm/loop/conftest.py @@ -0,0 +1,54 @@ +"""Factories and helpers for control loop tests.""" + +from unittest import mock + +import pandas as pd +from smart_buildings.smart_control.environment import conftest as env_conftest +from smart_buildings.smart_control.llm.agents import default_agent +from smart_buildings.smart_control.llm.loop import control_loop +from smart_buildings.smart_control.utils import writer_lib + +START_TIMESTAMP = pd.Timestamp('2025-12-12 00:00:00', tz='US/Pacific') + + +def create_loop( + start_timestamp: pd.Timestamp = START_TIMESTAMP, + loop_class: type[control_loop.ControlLoop] = control_loop.ControlLoop, + max_steps: int | None = 5, + hybrid: bool = True, + agent: default_agent.DefaultPolicyAgent | None = None, +) -> control_loop.ControlLoop: + """Creates a control loop, with a default agent, for testing purposes. + + Args: + start_timestamp: The start timestamp for the environment / building. + loop_class: The class of the loop to be created. + max_steps: The maximum number of steps to run the loop for. + hybrid: Whether to create a hybrid action environment. Default is True. + agent: The agent to use for the loop. A default agent will be created if + None. + + Returns: + A control loop, for testing purposes. + """ + + if hybrid: + env = env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT, + start_timestamp=start_timestamp, + default_actions=env_conftest.DEFAULT_HYBRID_ACTIONS, + ) + else: + env = env_conftest.create_environment( + layout=env_conftest.DEMO_LAYOUT, + start_timestamp=start_timestamp, + default_actions=env_conftest.DEFAULT_ACTIONS, + ) + + env._metrics_writer = mock.create_autospec( # pylint: disable=protected-access + writer_lib.BaseWriter, instance=True + ) + + agent = agent or default_agent.DefaultPolicyAgent(env=env, clip=True) + + return loop_class(agent=agent, max_steps=max_steps) diff --git a/smart_control/llm/loop/control_loop.py b/smart_control/llm/loop/control_loop.py new file mode 100644 index 00000000..bcc939a1 --- /dev/null +++ b/smart_control/llm/loop/control_loop.py @@ -0,0 +1,334 @@ +"""Agent control loop. + +The loop is a harness / driver to facilitate an agent's control of an +environment. The loop is responsible for getting observations from the +environment, getting actions from the agent, and stepping the environment to +apply those actions to the building. + +The loop runs a single episode, covering a specified number of days according to +the environment's configuration. It steps the environment on a regular time step +interval (usually every five minutes), as specified by the environment's +configuration. + +**Validity Interval** + +Some agents (like RL agents and baseline agents) may take actions +every time step, while others (like LLM agents) may choose to specify longer +validity intervals. The validity interval is the amount of time for which an +action is valid (i.e. the amount of time to wait before asking the agent for +another action). While the loop is waiting for the validity interval to expire, +it will apply the most recent action it has received, to step the environment +during every time step until the validity interval runs out. + +The loop will step the environment at every time step, but will only ask the +agent for a new action when the validity interval runs out. Agents like +baseline agents or RL agents that don't vary their validity intervals can use +the environment's time step interval in minutes, as a fixed default interval for +every action. Other agents like LLM agents may choose to specify longer validity +intervals, based on building conditions - for example an agent may choose to +wait two hours between actions, at night when conditions are stable and there +are no occupants in the building. + +The validity interval also acts as a cost-saving measure, as it can reduce the +number of API calls to the LLM (from around 288 to around 25 per day). + +**Action Context** + +The agent provides an action context object to the loop, which the loop uses to +step the environment. The action context contains the action itself, as well as +more context about the action, suchj as the validity interval, and +justifications / reasoning, as applicable. + +**Max Steps** + +The loop can be stopped early if a maximum number of steps is specified. This is +helpful for testing and debugging purposes. + +**Metrics** + +The basic control loop uses existing metrics writing functionality, triggering +protos to be written to file during each time step (see environment's methods to +get information about observations and rewards). +""" + +import logging +from typing import Any, Final + +import numpy as np +import pandas as pd +from smart_buildings.smart_control.llm.agents import default_agent +from smart_buildings.smart_control.proto import smart_control_reward_pb2 as reward_pb2 +from smart_buildings.smart_control.utils import writer_lib +from tf_agents.trajectories import time_step as ts + +SerializableData = dict[str, Any] + +ACTION_REJECTION_REWARD: Final[float] = -np.inf + + +def get_clock_timestamp() -> pd.Timestamp: + """Returns the actual current clock timestamp.""" + return pd.Timestamp.now().replace(microsecond=0, nanosecond=0) + + +def parse_timestamp(timestamp: pd.Timestamp, time_zone: str) -> pd.Timestamp: + """Ensures that a timestamp is timezone-aware.""" + if timestamp.tzinfo is None: + return timestamp.tz_localize(time_zone) + return timestamp.tz_convert(time_zone) + + +class ControlLoop: + """An agentic control loop. + + The loop is responsible for stepping the environment on a regular basis. + + The agent is called to get an action whenever the validity interval runs out. + + If a maximum number of steps is specified, the loop will stop running after + that number of steps. + + The loop will keep track of the agent's cumulative rewards over time. + + Attributes: + agent: The agent to use for the loop. + env: The environment to use for the loop. + metrics_writer: The metrics writer to use for the loop. + max_steps: The maximum number of steps to run the loop for. + cum_reward: The cumulative reward for the loop. + results: The results of the loop. + """ + + def __init__( + self, + agent: default_agent.DefaultPolicyAgent, + max_steps: int | None = None, + ): + """Initializes the instance. + + Args: + agent: The agent to use for the loop. + max_steps: The maximum number of steps to run the loop for. If None, the + loop will run until the environment has ended. + """ + self.agent = agent + self.env = self.agent.env + self.metrics_writer = self._validate_metrics_writer(self.env.metrics_writer) + + self.max_steps = max_steps + + self.cum_reward = 0.0 + self.results = [] + + def _interval_has_expired(self, remaining_interval: pd.Timedelta) -> bool: + """Checks whether the validity interval has expired. + + If so, it is time to get a new action from the agent. + + Args: + remaining_interval: timedelta representing the remaining interval to wait + before getting a new action from the agent. + + Returns: + Whether or not the interval has expired. + """ + return remaining_interval <= self.time_step_interval + + def _max_steps_reached(self, max_step: int | None) -> bool: + return max_step is not None and self.current_step >= max_step + + def _action_rejected(self, time_step: ts.TimeStep) -> bool: + """Checks whether the action was rejected by the environment.""" + return (time_step.reward == ACTION_REJECTION_REWARD).any() + + # + # MAIN LOOP + # + + def run(self) -> None: + """Runs the control loop for a single episode.""" + self.write_metadata() + + max_step = ( + self.current_step + self.max_steps + if self.max_steps is not None + else None + ) + + # GET INITIAL STATE + + observation_response = self.env.get_observation_response() + reward_info, reward_response = self.env.get_reward_info_and_response() + + # GET INITIAL AGENT ACTION + + action_ctx = self.agent.get_action_context( + observation_response=observation_response, + reward_info=reward_info, + ) + action = action_ctx.get_action() + remaining_interval = pd.Timedelta(minutes=action_ctx.validity_interval) + + while True: + if self.episode_has_ended: + logging.info("EPISODE HAS ENDED. STOPPING...") + break + + if self._max_steps_reached(max_step): + logging.info("MAX STEPS REACHED. STOPPING...") + break + + # STEP THE ENV (USING WHATEVER ACTION IT HAS MOST RECENTLY RECEIVED) + + time_step = self.env.step(action) + if self._action_rejected(time_step): + logging.warning("ACTION REJECTED BY THE ENVIRONMENT.") + + reward = time_step.reward.item() + self.cum_reward += float(reward) + logging.info("REWARD: %r --> %r", reward, self.cum_reward) + + # UPDATE RESULTS + + self.update_results( + reward=reward, + reward_info=reward_info, + reward_response=reward_response, + ) + + # GET NEW STATE + + observation_response = self.env.get_observation_response() + reward_info, reward_response = self.env.get_reward_info_and_response() + + # UPDATE ACTION (AS NECESSARY) + + if self._interval_has_expired(remaining_interval): + # VALIDITY INTERVAL HAS EXPIRED. GET A NEW ACTION FROM THE AGENT. + action_ctx = self.agent.get_action_context( + observation_response=observation_response, + reward_info=reward_info, + ) + action = action_ctx.get_action() + remaining_interval = pd.Timedelta(minutes=action_ctx.validity_interval) + else: + # CONTINUE WAITING FOR VALIDITY INTERVAL TO EXPIRE + remaining_interval -= self.time_step_interval + + # EPISODE HAS ENDED + + self.write_results() + + # + # ENVIRONMENT PROPERTIES + # + + @property + def start_timestamp(self) -> pd.Timestamp: + """The start timestamp, in environment's local time zone.""" + return parse_timestamp(self.env.start_timestamp, self.env.time_zone) + + @property + def end_timestamp(self) -> pd.Timestamp: + """The end timestamp, in the environment's local time zone.""" + return parse_timestamp(self.env.end_timestamp, self.env.time_zone) + + @property + def days_per_episode(self) -> int: + """The number of steps per episode.""" + return self.env.num_days_in_episode + + @property + def time_step_interval(self) -> pd.Timedelta: + """The time step in minutes, as a pandas Timedelta.""" + return pd.Timedelta(minutes=self.env.time_step_mins) + + @property + def steps_per_day(self) -> int: + """The number of steps per day.""" + return int(pd.Timedelta(days=1) / self.time_step_interval) + + @property + def steps_per_episode(self) -> int: + """The number of steps per episode.""" + return self.env._num_timesteps_in_episode # pylint: disable=protected-access + + @property + def episode_has_ended(self) -> bool: + """Whether the episode has ended.""" + return self.env._has_episode_ended() # pylint: disable=protected-access + + @property + def current_step(self) -> int: + """The current step number.""" + return self.env._step_count # pylint: disable=protected-access + + @property + def current_local_timestamp(self) -> pd.Timestamp: + """The current local timestamp.""" + return self.env.current_local_timestamp + + # + # METRICS + # + + def _validate_metrics_writer( + self, writer: writer_lib.BaseWriter + ) -> writer_lib.BaseWriter: + """Validates the metrics writer.""" + if writer is None: + raise ValueError("Metrics writer is None.") + + if not hasattr(writer, "output_dir"): + raise ValueError("Metrics writer does not have output_dir attribute.") + + if not hasattr(writer, "write_json"): + raise ValueError("Metrics writer does not have write_json method.") + + return writer + + @property + def metrics_output_dir(self) -> Any: + """The directory to write metrics to.""" + return self.metrics_writer.output_dir + + def write_metadata(self) -> None: + """Writes the metadata to a file (before running the loop).""" + self.metrics_writer.write_json(self.json_metadata, "metadata.json") + + @property + def json_metadata(self) -> SerializableData: + """Info about the loop's initial state and input parameters.""" + return { + "start_timestamp": str(self.start_timestamp), + "end_timestamp": str(self.end_timestamp), + "days_per_episode": self.days_per_episode, + "time_step_mins": self.env.time_step_mins, + "steps_per_episode": self.steps_per_episode, + "env": self.env.json_metadata, + "agent": self.agent.json_metadata, + } + + def update_results( + self, + reward: float, + reward_info: reward_pb2.RewardInfo, + reward_response: reward_pb2.RewardResponse, + ) -> None: + """Updates the results (after the current step has completed).""" + pass + + def write_results(self) -> None: + """Writes the results to a file (after the episode has completed).""" + self.metrics_writer.write_json(self.json_results, "results.json") + + @property + def json_results(self) -> SerializableData: + """Info about the loop's current / final state, after it has begun.""" + return { + "clock_timestamp": str(get_clock_timestamp()), + "current_timestamp": str(self.current_local_timestamp), + "current_step": self.current_step, + "cum_reward": self.cum_reward, + "results": self.results, + } diff --git a/smart_control/llm/loop/control_loop_test.py b/smart_control/llm/loop/control_loop_test.py new file mode 100644 index 00000000..fcaec2da --- /dev/null +++ b/smart_control/llm/loop/control_loop_test.py @@ -0,0 +1,349 @@ +from unittest import mock + +from absl.testing import absltest +import numpy as np +import pandas as pd +# pylint: disable=g-bad-import-order local package imports in their own section below third party packages +from smart_buildings.smart_control.environment import conftest as env_conftest +from smart_buildings.smart_control.environment import hybrid_action_environment +from smart_buildings.smart_control.llm.agents import default_agent +from smart_buildings.smart_control.llm.loop import conftest +from smart_buildings.smart_control.llm.loop import control_loop +from smart_buildings.smart_control.utils import writer_lib +from tf_agents.trajectories import time_step as ts + +CLOCK_TIMESTAMP = pd.Timestamp('2026-03-26 12:00:00') +EXAMPLE_TIME_STEP = ts.TimeStep( + step_type=ts.StepType.MID, + reward=np.array([10.0]), + discount=np.array(1.0), + observation=(), +) + + +class ClockTimestampTest(absltest.TestCase): + + def test_get_clock_timestamp(self): + with mock.patch.object( + pd.Timestamp, 'now', return_value=CLOCK_TIMESTAMP, autospec=True + ): + self.assertEqual( + control_loop.get_clock_timestamp(), + CLOCK_TIMESTAMP, + ) + + +class TimestampParserTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.time_zone = 'US/Pacific' + + def test_parse_timestamp_without_time_zone_localizes(self): + timestamp = pd.Timestamp('2025-12-12 00:00:00') + self.assertIsNone(timestamp.tzinfo) + self.assertEqual( + control_loop.parse_timestamp(timestamp, self.time_zone), + pd.Timestamp('2025-12-12 00:00:00', tz=self.time_zone), + ) + + def test_parse_timestamp_with_different_time_zone_converts(self): + timestamp = pd.Timestamp('2025-12-12 00:00:00', tz='UTC') + self.assertIsNotNone(timestamp.tzinfo) + self.assertEqual( + control_loop.parse_timestamp(timestamp, self.time_zone), + pd.Timestamp('2025-12-11 16:00:00', tz=self.time_zone), + ) + + def test_parse_timestamp_with_same_time_zone_remains_the_same(self): + timestamp = pd.Timestamp('2025-12-12 00:00:00', tz=self.time_zone) + self.assertIsNotNone(timestamp.tzinfo) + self.assertEqual( + control_loop.parse_timestamp(timestamp, self.time_zone), + pd.Timestamp('2025-12-12 00:00:00', tz=self.time_zone), + ) + + +class MetricsWriterValidationTest(absltest.TestCase): + + def _create_loop( + self, writer: writer_lib.BaseWriter + ) -> control_loop.ControlLoop: + env = env_conftest.create_hybrid_action_environment( + # writer_factory=lambda metrics_path: writer, + default_actions=env_conftest.DEFAULT_HYBRID_ACTIONS, + ) + agent = default_agent.DefaultPolicyAgent(env=env) + env._metrics_writer = writer + return control_loop.ControlLoop(agent=agent) + + def test_metrics_writer_with_valid_interface(self): + writer = mock.create_autospec(writer_lib.BaseWriter, instance=True) + self.assertTrue(hasattr(writer, 'output_dir')) + self.assertTrue(hasattr(writer, 'write_json')) + + loop = self._create_loop(writer=writer) + self.assertEqual(loop.metrics_writer, writer) + + def test_writer_without_output_dir_raises_error(self): + writer = mock.create_autospec(writer_lib.BaseWriter, instance=True) + del writer.output_dir + + with self.assertRaisesRegex( + ValueError, 'Metrics writer does not have output_dir attribute.' + ): + self._create_loop(writer=writer) + + def test_writer_without_write_json_method_raises_error(self): + writer = mock.create_autospec(writer_lib.BaseWriter, instance=True) + del writer.write_json + + with self.assertRaisesRegex( + ValueError, 'Metrics writer does not have write_json method.' + ): + self._create_loop(writer=writer) + + +class LoopTest(absltest.TestCase): + """Tests for the setup of the control loop, before it has run.""" + + def setUp(self): + super().setUp() + self.loop = conftest.create_loop(max_steps=5) + + def test_initialization(self): + self.assertIsInstance(self.loop, control_loop.ControlLoop) + + def test_agent(self): + self.assertIsInstance(self.loop.agent, default_agent.DefaultPolicyAgent) + + def test_env(self): + self.assertIsInstance( + self.loop.env, hybrid_action_environment.HybridActionEnvironment + ) + + def test_attributes(self): + with self.subTest(name='max_steps'): + self.assertEqual(self.loop.max_steps, 5) + + with self.subTest(name='cum_reward'): + self.assertEqual(self.loop.cum_reward, 0.0) + + # ENVIRONMENT ATTRIBUTES + + def test_timestamps(self): + with self.subTest(name='start_timestamp'): + self.assertEqual( + self.loop.start_timestamp, + pd.Timestamp('2025-12-12 00:00:00', tz='US/Pacific'), + ) + + with self.subTest(name='end_timestamp'): + self.assertEqual( + self.loop.end_timestamp, + pd.Timestamp('2025-12-15 00:00:00', tz='US/Pacific'), + ) + + with self.subTest(name='current_local_timestamp'): + self.assertEqual( + self.loop.current_local_timestamp, + self.loop.env.current_local_timestamp, + ) + + def test_step_attributes(self): + with self.subTest(name='days_per_episode'): + self.assertEqual(self.loop.days_per_episode, 3) + + with self.subTest(name='time_step_interval'): + self.assertEqual(self.loop.time_step_interval, pd.Timedelta(minutes=5)) + + with self.subTest(name='steps_per_day'): + self.assertEqual(self.loop.steps_per_day, 288) + + with self.subTest(name='steps_per_episode'): + self.assertEqual(self.loop.steps_per_episode, 864) + + with self.subTest(name='episode_has_ended'): + self.assertFalse(self.loop.episode_has_ended) + + with self.subTest(name='current_step'): + self.assertEqual(self.loop.current_step, 0) + + # METRICS + + def test_metrics_output_dir(self): + self.assertEqual( + self.loop.metrics_output_dir, self.loop.metrics_writer.output_dir + ) + + def test_write_metadata(self): + self.loop.env.metrics_writer.reset_mock() + self.loop.write_metadata() + self.loop.env.metrics_writer.write_json.assert_called_once_with( + self.loop.json_metadata, 'metadata.json' + ) + + def test_write_results(self): + self.loop.env.metrics_writer.reset_mock() + with mock.patch.object( + control_loop, + 'get_clock_timestamp', + return_value=CLOCK_TIMESTAMP, + autospec=True, + ): + self.loop.write_results() + self.loop.env.metrics_writer.write_json.assert_called_once_with( + self.loop.json_results, 'results.json' + ) + + def test_json_metadata(self): + self.assertEqual( + self.loop.json_metadata, + { + 'start_timestamp': '2025-12-12 00:00:00-08:00', + 'end_timestamp': '2025-12-15 00:00:00-08:00', + 'days_per_episode': 3, + 'time_step_mins': 5, + 'steps_per_episode': 864, + 'env': self.loop.env.json_metadata, + 'agent': self.loop.agent.json_metadata, + }, + ) + + +class LoopResultsTest(absltest.TestCase): + """Tests for the results of the control loop, after it has run.""" + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.loop = conftest.create_loop(max_steps=5) + + # RUN THE LOOP (SO WE CAN TEST THE RESULTS AFTERWARDS) + + original_step_function = cls.loop.env.step + + def step_side_effect(*args, **kwargs): + time_step = original_step_function(*args, **kwargs) + return time_step._replace(reward=np.array([10.0])) + + with mock.patch.object( + cls.loop.env, 'step', side_effect=step_side_effect, autospec=True + ), mock.patch.object( + control_loop, + 'get_clock_timestamp', + autospec=True, + ) as mock_clock_timestamp: + mock_clock_timestamp.return_value = CLOCK_TIMESTAMP + cls.loop.run() + + def test_json_results(self): + with mock.patch.object( + control_loop, + 'get_clock_timestamp', + return_value=CLOCK_TIMESTAMP, + autospec=True, + ): + self.assertEqual( + self.loop.json_results, + { + 'clock_timestamp': '2026-03-26 12:00:00', + 'current_timestamp': '2025-12-12 00:25:00-08:00', + 'current_step': 5, + 'cum_reward': 50.0, + 'results': [], + }, + ) + + +class LoopEndsWhenEpisodeEndsTest(absltest.TestCase): + """Tests that the loop stops when episode has ended.""" + + def test_stops_when_episode_has_ended(self): + loop = conftest.create_loop(max_steps=None) + with mock.patch.object( + control_loop.ControlLoop, + 'episode_has_ended', + new_callable=mock.PropertyMock, + side_effect=[False, False, True], + ) as mock_ended: + loop.run() + + self.assertEqual(mock_ended.call_count, 3) + self.assertEqual(loop.current_step, 2) + + +class ActionRejectionTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.loop = conftest.create_loop(max_steps=1) + + def test_action_rejection_reward(self): + self.assertEqual(control_loop.ACTION_REJECTION_REWARD, -np.inf) + + def test_action_rejected_returns_true_when_reward_is_neg_inf(self): + time_step = ts.TimeStep( + step_type=ts.StepType.MID, + reward=np.array([control_loop.ACTION_REJECTION_REWARD]), + discount=np.array(1.0), + observation=(), + ) + self.assertTrue(self.loop._action_rejected(time_step)) + + def test_action_rejected_returns_false_when_reward_is_not_neg_inf(self): + self.assertFalse(self.loop._action_rejected(EXAMPLE_TIME_STEP)) + + +class IntervalTest(absltest.TestCase): + + def test_validity_interval(self): + loop = conftest.create_loop(max_steps=5) + action_ctx = mock.Mock() + action_ctx.validity_interval = 10 # minutes + action_ctx.get_action.return_value = env_conftest.DEFAULT_HYBRID_ACTIONS + + # All this mocking and patching helps the environment step very fast, to + # drastically reduce the time it takes to run this test. + def step_side_effect(*args, **kwargs): + del args, kwargs # Unused. + loop.env._step_count += 1 + return EXAMPLE_TIME_STEP + + with mock.patch.object( + loop.agent, + 'get_action_context', + return_value=action_ctx, + autospec=True, + ) as mock_get_action_context: + with mock.patch.object( + loop.env, + 'step', + side_effect=step_side_effect, + autospec=True, + ) as mock_step: + with mock.patch.object( + loop.env, + 'get_observation_response', + return_value=mock.Mock(), + autospec=True, + ): + with mock.patch.object( + loop.env, + 'get_reward_info_and_response', + return_value=(mock.Mock(), mock.Mock()), + autospec=True, + ): + loop.run() + + # The agent provides an initial action before the first step. + # The environment is stepped five times, once every five minutes, for a + # total duration of 25 minutes. Because the validity interval is 10 minutes, + # the agent is only asked to get an action twice more during this time (for + # a total of three actions). + self.assertEqual(mock_step.call_count, 5) + self.assertEqual(mock_get_action_context.call_count, 3) + + +if __name__ == '__main__': + absltest.main() diff --git a/smart_control/llm/prompts/base_promptmaker.py b/smart_control/llm/prompts/base_promptmaker.py new file mode 100644 index 00000000..3db1ac87 --- /dev/null +++ b/smart_control/llm/prompts/base_promptmaker.py @@ -0,0 +1,142 @@ +"""Base class for promptmakers. + +Promptmakers are responsible for compiling a prompt for an LLM. + +Promptmakers are set up to combine a user-provided 'base prompt' with output +formatting instructions automatically derived from a Pydantic model, to arrive +at the final 'prompt' which gets sent to the LLM. + +This base class can be flexibility used with any Pydantic model, but child +classes will use a specific Pydantic model suited for building control. +""" + +import abc +import textwrap +from typing import Any, Callable + +import langchain.output_parsers +import pydantic + +PydanticOutputParser = langchain.output_parsers.PydanticOutputParser + +SerializableData = dict[str, Any] + +DedentFunction = Callable[[str], str] + + +def full_dedent(txt: str) -> str: + """Removes all leading whitespace from each line in a string. + + While textwrap.dedent is designed to preserve the relative indentation within + a block of text, this function removes all leading whitespace from each line, + regardless of the relative indentation. + + This behavior is helpful when you want to define a prompt as a multiline + string inside a function or method, and you want to ensure all lines in the + resulting prompt are left-justified, ignoring any indentation used for + readability in the source code. + + This is also relevant when a prompt is dynamically compiled using + multiple sections, including nested sub-sections that are defined in their own + methods in the promptmaker class. This behavior can prevent stacking of + relative indentation from nested blocks of code. + + If you have a markdown multi-level list, you would want to use + textwrap.dedent instead, to preserve the relative indentation of the list. + + Args: + txt: The string to remove leading whitespace from. + + Returns: + The string with all leading whitespace removed. + """ + return '\n'.join(line.lstrip() for line in txt.strip().splitlines()) + + +class BasePromptmaker(abc.ABC): + """Base Promptmaker. + + A Promptmaker is responsible for compiling a prompt for an LLM. + + The Promptmaker uses a Pydantic model to provide formatting instructions that + result in the LLM producing reliable JSON formatted string responses. + + You override the `base_prompt` property to provide the main prompt, and the + `output_schema_class` argument to specify the Pydantic model used to provide + formatting instructions. Then the promptmaker combines your base prompt with + formatting instructions in the final `prompt` property, which you can send to + the LLM. + """ + + def __init__( + self, + output_schema_class: type[pydantic.BaseModel], + dedent: DedentFunction = textwrap.dedent, + ): + """Initializes the instance. + + Args: + output_schema_class: The pydantic model class used to provide JSON + response formatting instructions in the prompt. + dedent: The function used to remove leading whitespace from the prompt. + """ + self.output_schema_class = output_schema_class + self.dedent = dedent + + @property + @abc.abstractmethod + def base_prompt(self) -> str: + """The main prompt, fully hydrated with data as necessary. + + The `base_prompt` does not include formatting instructions, as they are + automatically added in the `prompt` property. + """ + + @property + def prompt(self) -> str: + """The final prompt, including response formatting instructions.""" + return self.dedent( + '\n\n'.join(( + self.base_prompt, + self.formatting_instructions_section, + )) + ) + + @property + def formatting_instructions_section(self) -> str: + """The section of the prompt containing formatting instructions.""" + return '\n'.join([ + '## Formatting Instructions\n', + ( + 'IMPORTANT: The output MUST be a single, valid JSON object ' + 'conforming to the schema below.' + ), + ( + 'Do NOT include any other text, explanations, pleasantries, or any ' + 'other content before or after the JSON object.' + ), + self.formatting_instructions, + ]) + + @property + def formatting_instructions(self) -> str: + """Formatting instructions for the desired LLM output structure.""" + return self.output_parser.get_format_instructions() + + @property + def output_parser(self) -> PydanticOutputParser: + """A parser that derives formatting instructions from a pydantic model.""" + return PydanticOutputParser(pydantic_object=self.output_schema_class) + + @property + def output_schema(self) -> dict[str, Any]: + """The JSON schema for the output.""" + return self.output_schema_class.model_json_schema() + + @property + def json_metadata(self) -> SerializableData: + """Metadata about the promptmaker, suitable for JSON serialization.""" + return { + 'type': self.__class__.__name__, + 'output_schema_class': self.output_schema_class.__name__, + } diff --git a/smart_control/llm/prompts/base_promptmaker_test.py b/smart_control/llm/prompts/base_promptmaker_test.py new file mode 100644 index 00000000..2acf09ac --- /dev/null +++ b/smart_control/llm/prompts/base_promptmaker_test.py @@ -0,0 +1,189 @@ +import json +import textwrap +from typing import Callable + +from absl.testing import absltest +import immutabledict +import langchain +import pydantic +from smart_buildings.smart_control.llm.prompts import base_promptmaker +from smart_buildings.smart_control.llm.schema import conftest as schema_conftest + +BASE_PROMPT = "What year was America founded?" + +EXPECTED_OUTPUT_SCHEMA = immutabledict.immutabledict({ + "title": "ExampleOutputSchema", + "description": ( + "Simple example implementation of an output schema, for testing" + " purposes." + ), + "type": "object", + "properties": { + "year": { + "description": "The year, as an integer.", + "title": "Year", + "type": "integer", + }, + "explanation": { + "description": "The reasoning behind choosing this specific year.", + "title": "Explanation", + "type": "string", + }, + }, + "required": ["year", "explanation"], +}) + + +class ExampleOutputSchema(pydantic.BaseModel): + """Simple example implementation of an output schema, for testing purposes.""" + + year: int = pydantic.Field(description="The year, as an integer.") + + explanation: str = pydantic.Field( + description="The reasoning behind choosing this specific year." + ) + + +class ExamplePromptmaker(base_promptmaker.BasePromptmaker): + """Simple example implementation of BasePromptmaker, for testing purposes.""" + + def __init__(self, dedent: Callable[[str], str] = textwrap.dedent): + super().__init__(output_schema_class=ExampleOutputSchema, dedent=dedent) + + @property + def base_prompt(self) -> str: + return BASE_PROMPT + + +# +# TESTS +# + + +class DedentTest(absltest.TestCase): + """Tests to contrast different dedentation behavior.""" + + def setUp(self): + super().setUp() + self.base_prompt = """\ + Hello world! + Hello world! + """ + + def test_no_dedent_leaves_leading_whitespace(self): + pm = ExamplePromptmaker(dedent=lambda txt: txt) + self.assertEqual( + pm.dedent(self.base_prompt), + " Hello world!\n Hello world!\n ", + ) + + def test_textwrap_dedent_leaves_leading_relative_whitespace(self): + pm = ExamplePromptmaker(dedent=textwrap.dedent) + self.assertEqual( + pm.dedent(self.base_prompt), + "Hello world!\n Hello world!\n", + ) + + def test_full_dedent_removes_all_leading_whitespace(self): + pm = ExamplePromptmaker(dedent=base_promptmaker.full_dedent) + self.assertEqual( + pm.dedent(self.base_prompt), + "Hello world!\nHello world!", + ) + + +class BasePromptmakerTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.promptmaker = ExamplePromptmaker() + + def test_initialization(self): + self.assertIsInstance(self.promptmaker, base_promptmaker.BasePromptmaker) + self.assertEqual(self.promptmaker.output_schema_class, ExampleOutputSchema) + + def test_base_prompt(self): + self.assertEqual(self.promptmaker.base_prompt, BASE_PROMPT) + + def test_prompt(self): + self.assertEqual( + self.promptmaker.prompt, + f"{BASE_PROMPT}\n\n{self.promptmaker.formatting_instructions_section}", + ) + + def test_formatting_instructions_section(self): + self.assertEqual( + self.promptmaker.formatting_instructions_section, + ( + "## Formatting Instructions\n\n" + "IMPORTANT: The output MUST be a single, valid JSON object " + "conforming to the schema below.\n" + "Do NOT include any other text, explanations, pleasantries, or " + "any other content before or after the JSON object.\n" + f"{self.promptmaker.formatting_instructions}" + ), + ) + + def test_formatting_instructions(self): + instructions = self.promptmaker.formatting_instructions + self.assertIsInstance(instructions, str) + + parsed_schema = schema_conftest.parse_instructions_schema(instructions) + expected_schema = dict(EXPECTED_OUTPUT_SCHEMA) # a shallow copy + del expected_schema["title"] + del expected_schema["type"] + + with self.subTest(name="introduces_the_schema"): + self.assertStartsWith( + instructions, + "The output should be formatted as a JSON instance that conforms" + " to the JSON schema below.", + ) + + with self.subTest(name="provides_an_example_schema"): + self.assertIn( + ( + 'As an example, for the schema {"properties": {"foo": {"title":' + ' "Foo", "description": "a list of strings", "type": "array",' + ' "items": {"type": "string"}}}, "required": ["foo"]}\nthe' + ' object {"foo": ["bar", "baz"]} is a well-formatted instance of' + ' the schema. The object {"properties": {"foo": ["bar", "baz"]}}' + " is not well-formatted." + ), + instructions, + ) + + with self.subTest(name="provides_output_schema"): + self.assertEqual(parsed_schema, expected_schema) + self.assertEndsWith( + instructions, + "Here is the output schema:\n```\n" + + json.dumps(expected_schema) + + "\n```", + ) + + def test_output_parser(self): + self.assertIsInstance( + self.promptmaker.output_parser, + langchain.output_parsers.PydanticOutputParser, + ) + self.assertEqual( + self.promptmaker.output_parser.pydantic_object, + ExampleOutputSchema, + ) + + def test_output_schema(self): + self.assertEqual(self.promptmaker.output_schema, EXPECTED_OUTPUT_SCHEMA) + + def test_json_metadata(self): + self.assertEqual( + self.promptmaker.json_metadata, + { + "type": "ExamplePromptmaker", + "output_schema_class": "ExampleOutputSchema", + }, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/smart_control/llm/prompts/floor_based_promptmaker.py b/smart_control/llm/prompts/floor_based_promptmaker.py new file mode 100644 index 00000000..3108ac36 --- /dev/null +++ b/smart_control/llm/prompts/floor_based_promptmaker.py @@ -0,0 +1,35 @@ +"""Promptmaker class with floor-specific zone comfort info.""" + +import functools + +import pandas as pd +from smart_buildings.smart_control.llm.prompts import promptmaker as pm + + +class FloorBasedPromptmaker(pm.Promptmaker): + """Updated promptmaker class, with floor-specific zone comfort info.""" + + @functools.cached_property + def zone_conditions_histogram_by_floor(self) -> pd.DataFrame: + """A histogram of zone conditions by floor.""" + return self.reward_info_parser.get_zone_conditions_histogram_by_floor( + zones=self.env.building.zones + ).T + + @property + def zone_conditions_subsection(self) -> str: + """A section describing the current conditions in the building.""" + + return self.dedent(f""" + ### Current Zone Temperatures + + The table below conveys the comfort conditions across all zones in the building, by floor: + + {self.zone_conditions_histogram_by_floor.to_markdown(index=True)} + + The row 'occupancy_count' shows the total number of occupants building-wide at a specific temperature. + The row 'setpoint_mask' indicates with a '0' if the temperature is within comfort range, a '-1' if the temperature is too cold, and a '1' if the temperature is too hot. + The row 'setpoint_range' indicates with '+' if the temperature is inside the acceptable range, and '-' if it is outside. + The row 'exposed_count' indicates the count of occupants being exposed to unacceptable comfort conditions. + The rows starting with 'occ@floor' show the normalized distribution of zone counts for each floor at that temperature. + """) diff --git a/smart_control/llm/prompts/floor_based_promptmaker_test.py b/smart_control/llm/prompts/floor_based_promptmaker_test.py new file mode 100644 index 00000000..f72ee7d0 --- /dev/null +++ b/smart_control/llm/prompts/floor_based_promptmaker_test.py @@ -0,0 +1,94 @@ +from absl.testing import absltest +from smart_buildings.smart_control.environment import conftest as env_conftest +from smart_buildings.smart_control.llm.prompts import floor_based_promptmaker +from smart_buildings.smart_control.utils import temperature_conversion as tc + + +class FloorBasedPromptmakerTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.zone_reward_configs = { + 'zone_1': { + 'zone_air_temperature': 292.1, + 'heating_setpoint_temperature': 294.0, + 'cooling_setpoint_temperature': 296.0, + 'average_occupancy': 5.0, + }, + 'zone_2': { + 'zone_air_temperature': 296.2, + 'heating_setpoint_temperature': 294.0, + 'cooling_setpoint_temperature': 296.0, + 'average_occupancy': 10.0, + }, + 'zone_3': { + 'zone_air_temperature': 297.9, + 'heating_setpoint_temperature': 294.0, + 'cooling_setpoint_temperature': 296.0, + 'average_occupancy': 3.0, + }, + } + self.env = env_conftest.create_hybrid_action_environment( + layout=env_conftest.MULTI_FLOOR_LAYOUT, + zone_reward_configs=self.zone_reward_configs, + ) + self.pm = floor_based_promptmaker.FloorBasedPromptmaker( + env=self.env, temp_display_unit=tc.TempUnit.KELVIN + ) + + def test_zone_conditions_histogram_by_floor(self): + df = self.pm.zone_conditions_histogram_by_floor + # The histogram is transposed in FloorBasedPromptmaker. + # Index should include occupancy_count, setpoint_range, exposed_count, + # and floor distribution(s). + self.assertIn('occupancy_count', df.index) + self.assertIn('setpoint_range', df.index) + self.assertIn('exposed_count', df.index) + + floor_rows = [i for i in df.index if str(i).startswith('occ@floor')] + self.assertCountEqual(floor_rows, ['occ@floor1', 'occ@floor2']) + + # Global occupancy: 5 at 292, 10 at 296, 3 at 298. + self.assertEqual(df.loc['occupancy_count', 292.0], 5) + self.assertEqual(df.loc['occupancy_count', 296.0], 10) + self.assertEqual(df.loc['occupancy_count', 298.0], 3) + + # Floor 1 distribution: zone_1 (temp 292) and zone_2 (temp 296). + # Since they are normalized, each should be 0.5 at their respective bins. + self.assertEqual(df.loc['occ@floor1', 292.0], 0.5) + self.assertEqual(df.loc['occ@floor1', 296.0], 0.5) + + # Floor 2 distribution: zone_3 (temp 298). + self.assertEqual(df.loc['occ@floor2', 298.0], 1.0) + + def test_zone_conditions_histogram_by_floor_is_always_kelvin(self): + # Setup promptmaker with Fahrenheit as display unit + pm = floor_based_promptmaker.FloorBasedPromptmaker( + env=self.env, + temp_display_unit=tc.TempUnit.FAHRENHEIT, + ) + + df = pm.zone_conditions_histogram_by_floor + + # Even though display unit is F, the table data passed to LLM stays in K. + # Global occupancy: 5 at 292, 10 at 296, 3 at 298. + self.assertEqual(df.loc['occupancy_count', 292.0], 5) + self.assertEqual(df.loc['occupancy_count', 296.0], 10) + self.assertEqual(df.loc['occupancy_count', 298.0], 3) + + # Verify the prompt text mentions Fahrenheit + self.assertIn('communicate temperatures in Fahrenheit', pm.base_prompt) + + def test_current_conditions_section(self): + section = self.pm.current_conditions_section + self.assertIn('## Current Conditions', section) + self.assertIn('### Current Zone Temperatures', section) + self.assertIn('by floor:', section) + self.assertIn("The rows starting with 'occ@floor'", section) + + table = self.pm.zone_conditions_histogram_by_floor.to_markdown(index=True) + self.assertIn(table, section) + + +if __name__ == '__main__': + absltest.main() diff --git a/smart_control/llm/prompts/generator.py b/smart_control/llm/prompts/generator.py new file mode 100644 index 00000000..81d95c6d --- /dev/null +++ b/smart_control/llm/prompts/generator.py @@ -0,0 +1,50 @@ +"""Utilities for generating example prompts. + +Creates an example prompt and writes it to a markdown file in the "examples" +directory. This helps facilitate developer reviews of the prompt. Once written, +you can use the text editor's markdown preview functionality to view the prompt +and verify the formatting renders correctly. +""" + +import os +from typing import Type + +from absl import logging +from smart_buildings.smart_control.environment import hybrid_action_environment +from smart_buildings.smart_control.llm.prompts import promptmaker + + +def write_prompt_md( + promptmaker_class: Type[promptmaker.Promptmaker], + include_weights: bool, + dirpath: str, + filename: str, +) -> None: + """Generates an example prompt and writes it to a markdown file. + + Args: + promptmaker_class: The promptmaker class to use. + include_weights: Whether to include weights in the prompt. + dirpath: The directory to write the markdown file to. + filename: The name of the markdown file to write. + """ + + logging.info("LOADING ENVIRONMENT...") + env = hybrid_action_environment.HybridActionEnvironment() + logging.info("Current local timestamp: %s", env.current_local_timestamp) + env.reset() + + logging.info("CREATING PROMPTMAKER: %s...", promptmaker_class.__name__) + pm = promptmaker_class(env=env, include_weights=include_weights) + + logging.info("SETTING UP EXAMPLE PROMPTS DIRECTORY...") + examples_dirpath = os.path.join(dirpath, "examples") + os.makedirs(examples_dirpath, exist_ok=True) + + logging.info("WRITING PROMPT TO %s...", filename) + md_filepath = os.path.join(examples_dirpath, filename) + with open(md_filepath, "w") as f: + f.write(pm.prompt) + f.write("\n") + + logging.info("DONE") diff --git a/smart_control/llm/prompts/promptmaker.py b/smart_control/llm/prompts/promptmaker.py new file mode 100644 index 00000000..3c832f2f --- /dev/null +++ b/smart_control/llm/prompts/promptmaker.py @@ -0,0 +1,555 @@ +"""Promptmaker for optimal control of HVAC systems in smart buildings. + +This promptmaker extends the base promptmaker class to create a prompt for +controlling HVAC systems in smart buildings. + +It uses the SetpointsAction pydantic model to provide formatting instructions +for the LLM response, to include a validity interval, overall strategy, a list +of setpoints and corresponding setpoint-specific justifications. + +This promptmaker constructs a basic non-opinionated prompt that could be used +as a basis for more specialized child classes. Prompts are expected to be an +active area of experimentation, so this class is designed to support +extensibility. + +The promptmaker uses a number of 'sections' that comprise the prompt. Each +section is a piece of the prompt that serves a specific purpose. By inheriting +from the promptmaker class, you can override specific sections to customize +the prompt without having to rewrite the entire prompt. + +In terms of content formatting, we are using Markdown. Research suggests this +may help the LLM better understand the organizational structure of the content. +See 'Does Prompt Formatting Have Any Impact on LLM Performance?' by He, et al. + +We are also using new-line characters to separate each sentence, keeping each +sentence fully contained on the same line. There is research to suggest that +new-line characters are effective delimiters for helping the LLM understand the +content (specifically examples) and generate a better response. See: 'A single +character can make or break your LLM evals' by Jingtong Su, et al. +""" + +import dataclasses +from typing import Any, Callable, Final + +import pandas as pd +from smart_buildings.smart_control.environment import environment +from smart_buildings.smart_control.environment import hybrid_action_environment as hybrid_env +from smart_buildings.smart_control.llm.prompts import base_promptmaker +from smart_buildings.smart_control.llm.schema import output_schema +from smart_buildings.smart_control.proto import smart_control_building_pb2 as building_pb2 +from smart_buildings.smart_control.proto import smart_control_reward_pb2 as reward_pb2 +from smart_buildings.smart_control.utils import temperature_conversion as tc +from smart_buildings.smart_control.utils.proto_parsers import observation_response_parser as or_parser +from smart_buildings.smart_control.utils.proto_parsers import reward_info_parser as ri_parser + +SerializableData = dict[str, Any] + +# TODO(mjrossetti): Consider importing these constants from other more central +# locations related to the devices, once they are available there. +AHU_STATIC_PRESSURE_SETPOINT: Final[str] = "static_pressure_setpoint" +AHU_SUPPLY_AIR_TEMPERATURE_SETPOINT: Final[str] = "supply_air_temperature_setpoint" # pylint: disable=line-too-long +HWS_DIFFERENTIAL_PRESSURE_SETPOINT: Final[str] = "differential_pressure_setpoint" # pylint: disable=line-too-long +HWS_SUPPLY_WATER_TEMPERATURE_SETPOINT: Final[str] = "supply_water_setpoint" + + +@dataclasses.dataclass +class BuildingInfo: + """Information about the building under control. + + This information is provided to the LLM to give it context about the building. + + Attributes: + stories: The number of stories in the building. + sqft: The square footage of the building. + location: The location of the building. + name: The name of the building, if applicable. + """ + + name: str = "SB-1" + stories: str = "two" + sqft: int = 96_000 + location: str = "Mountain View, California" + + +class Promptmaker(base_promptmaker.BasePromptmaker): + """Promptmaker for building control. + + This specific promptmaker assumes you are using a HybridActionEnvironment. + """ + + def __init__( + self, + env: environment.Environment, + *, + observation_response: building_pb2.ObservationResponse | None = None, + reward_info: reward_pb2.RewardInfo | None = None, + building_info: BuildingInfo | None = None, + output_schema_class: ( + type[output_schema.SetpointsAction] | None + ) = output_schema.SetpointsAction, + dedent: Callable[[str], str] = base_promptmaker.full_dedent, + include_weights: bool = False, + occupancy_mode_min: int = 10, + temp_display_unit: tc.TempUnit | str = tc.TempUnit.FAHRENHEIT, + lazy_init_protos: bool = False, + ): + """Initializes the instance. + + Args: + env: The environment containing information about the building, + observation space, action space, reward function, etc. + observation_response: The observation response from the environment. If + None, the observation response will be retrieved from the environment. + reward_info: The reward info from the environment. If None, the reward + info will be retrieved from the environment. + building_info: Information about the building being controlled, such as + the number of stories, square footage, and location. + output_schema_class: The pydantic model class used to provide JSON + response formatting instructions in the prompt. Uses the pre-configured + `SetpointsAction` model by default. To use custom validity interval + options, construct a custom output schema class using the + `output_schema.create_action_model` function, and pass that class here. + dedent: The function used to remove leading whitespace from the prompt. + Uses the `full_dedent` function by default, because otherwise the + inserted tables seem to be aligned to the left of the rest of the + content. + include_weights: Whether to include the reward function weights in the + prompt. + occupancy_mode_min: The minimum number of occupants in the building to + be considered in occupancy mode. + temp_display_unit: The temperature unit to be used by the LLM in its + justifications and reasoning. All input temperatures are in Kelvin. + lazy_init_protos: Whether to lazily setup the observation + response and reward info. If False, (by default), the protos + should be passed in during initialization, or will automatically be set, + for convenience. If True, the protos are expected to be passed in after + initialization, using the `set_protos` method. + """ + super().__init__(output_schema_class=output_schema_class, dedent=dedent) + self.env = env + self.include_weights = include_weights + self.occupancy_mode_min = occupancy_mode_min + self.temp_display_unit = tc.assign_temp_unit(temp_display_unit) + self.building_info = building_info or BuildingInfo() + self.lazy_init_protos = lazy_init_protos + self._observation_response_parser: ( + or_parser.ObservationResponseParser | None + ) = None + self._reward_info_parser: ri_parser.RewardInfoParser | None = None + + if not self.lazy_init_protos: + self.set_protos( + observation_response=observation_response, + reward_info=reward_info, + ) + + def set_protos( + self, + observation_response: building_pb2.ObservationResponse | None, + reward_info: reward_pb2.RewardInfo | None, + ) -> None: + """Sets up the observation response and reward info parsers. + + If you lazy initialized the protos, you must call this method to set them. + + Args: + observation_response: The observation response from the environment. If + None, the observation response will be retrieved from the environment. + reward_info: The reward info from the environment. If None, the reward + info will be retrieved from the environment. + """ + self._observation_response_parser = self._setup_observation_response( + observation_response=observation_response, + ) + self._reward_info_parser = self._setup_reward_info(reward_info=reward_info) + + def _setup_observation_response( + self, + observation_response: building_pb2.ObservationResponse | None = None, + ) -> or_parser.ObservationResponseParser: + """Returns an observation response parser. + + Args: + observation_response: The observation response from the environment. If + None, the observation response will be retrieved from the environment. + + Returns: + An observation response parser. + """ + if observation_response is None: + observation_response = self.env.get_observation_response() + + return or_parser.ObservationResponseParser( + observation_response=observation_response + ) + + def _setup_reward_info( + self, reward_info: reward_pb2.RewardInfo | None = None + ) -> ri_parser.RewardInfoParser: + """Returns a reward info parser. + + Args: + reward_info: The reward info from the environment. If None, the reward + info will be retrieved from the environment. + + Returns: + A reward info parser. + """ + if reward_info is None: + reward_info = self.env.get_reward_info() + + return ri_parser.RewardInfoParser(reward_info=reward_info) + + @property + def observation_response_parser(self) -> or_parser.ObservationResponseParser: + """The observation response parser. Assumed to have been set up already.""" + if self._observation_response_parser is None: + raise ValueError("Observation response parser is None.") + return self._observation_response_parser + + @property + def reward_info_parser(self) -> ri_parser.RewardInfoParser: + """The reward info parser. Assumed to have been set up already.""" + if self._reward_info_parser is None: + raise ValueError("Reward info parser is None.") + return self._reward_info_parser + + # DATA AND PROPERTIES + + @property + def json_metadata(self) -> SerializableData: + """Info to write into a JSON file. Needs to be serializable.""" + return super().json_metadata | { + "include_weights": self.include_weights, + "occupancy_mode_min": self.occupancy_mode_min, + "temp_display_unit": self.temp_display_unit.value, + "building_info": dataclasses.asdict(self.building_info), + } + + @property + def building_info_series(self) -> pd.Series: + """A pandas.Series describing the building information.""" + return pd.Series( + dataclasses.asdict(self.building_info), name="building_info" + ) + + @property + def setpoints_df(self) -> pd.DataFrame: + """A dataframe describing the devices and setpoints under control. + + Includes information about the range of possible native values for each + setpoint. + + The LLM will use the device_id and setpoint_name values as a composite key + to uniquely identify setpoints in its responses. + + Returns: + A dataframe describing the devices and setpoints under control. + """ + df = self.env.action_fields_df[[ + "device_id", + "setpoint_name", + "setpoint_type", + "units", + "min_native_value", + "max_native_value", + ]].copy() + return df.sort_values(by=["device_id", "setpoint_name"]).reset_index( + drop=True + ) + + @property + def weights(self) -> dict[str, float] | None: + """Returns the reward function weights, if available.""" + if hasattr(self.env.reward_function, "weights"): + weights = self.env.reward_function.weights.copy() + # Rename "productivity_weight" to "comfort_weight": + if "productivity_weight" in weights: + weights["comfort_weight"] = weights.pop("productivity_weight") + return weights + return None + + @property + def weights_series(self) -> pd.Series | None: + """A pandas.Series describing the reward function weights, if available.""" + if self.weights is not None: + return pd.Series(self.weights, name="weight") + + @property + def validity_intervals(self) -> list[int]: + """A list of validity intervals (in minutes) for the LLM to choose from.""" + return self.output_schema["properties"]["validity_interval"]["enum"] + + # PROMPT CONTENT + + @property + def base_prompt(self) -> str: + """The base prompt, excluding formatting instructions.""" + return "\n\n".join([ + "# Agent Instructions", + self.objectives_section, + self.zone_info_section, + self.occupancy_modes_section, + self.hvac_system_guidelines_section, + self.action_guidelines_section, + self.current_conditions_section, + self.current_action_section, + ]) + + @property + def objectives_section(self) -> str: + """A section describing the LLM's role and objectives. + + Includes the reward function weights, if available and enabled via the + `include_weights` argument. + + Returns: + A section describing the LLM's role and objectives. + """ + + section = self.dedent(f""" + ## Objectives + + ### Role + + You are a skilled, experienced, and innovative operator of a commercial office building. + You possess in-depth and complete knowledge about HVAC systems, as well as ASHRAE standards and certifications. + Your job is to optimally control HVAC devices in a given commercial office building. + + **Building Information**: + + {self.building_info_series.to_markdown(index=True)} + + ### Overall Goal + + As the building operator, your **Optimal Control Objectives** are to: + + + Minimize energy consumption / costs, and + + Minimize carbon emissions, and + + Maintain occupant comfort (a.k.a. productivity) + + This is a multi-objective optimization problem, where you must balance competing objectives. + """) + + weights_series = self.weights_series + if self.include_weights and weights_series is not None: + section += "\n\n" + self.dedent(f""" + ### Reward Function Weights + + We have assigned a weight to designate the importance of each objective. + Your job is to maximize the weighted sum of the objectives, placing a higher priority on objectives with greater weights. + The weights are designated in the table below: + + {weights_series.to_markdown(index=True)} + """) + + return self.dedent(section) + + @property + def zone_info_section(self) -> str: + """A section describing zone related terminology.""" + + return self.dedent(""" + ## Zone Information + + A **zone** is a room, or space in the office building that is potentially occupied by humans, and must be conditioned for comfort when occupied. + + ### Zone Comfort + + The **zone air temperature** is the average temperature in a zone and the measure of comfort in the zone. + + The **zone air heating setpoint** is the minimum temperature that zone is allowed to be, without actively heating the zone. + It's like the minimum of the occupant comfort range. + The **zone air cooling setpoint** is the maximum temperature that zone is allowed to be, without actively cooling the zone. + It's like the maximum of the occupant comfort range. + The zone air heating temperature setpoint is always below the zone air cooling temperature setpoint. + + Ideally: `zone air heating setpoint < zone air temperature if occupied < zone air cooling setpoint` + """) + + @property + def occupancy_modes_section(self) -> str: + """A section describing and contrasting the different occupancy modes.""" + + # TODO(mjrossetti): Add a table of hourly occupancy trends, for each day of + # the week. + + return self.dedent(f""" + ## Occupancy Modes + + You should operate the building in an occupancy mode and an efficiency mode. + + **Occupancy mode** is when the building has at least {self.occupancy_mode_min} occupants. + When in occupancy mode, you should try to maintain zone air temperatures within comfort range (for all occupied zones), while also minimizing energy consumption and carbon emissions. + + **Efficiency mode** is when the building has fewer than {self.occupancy_mode_min} occupants. + When in efficiency mode, your only objective should be to SIGNIFICANTLY reduce energy consumption and carbon emissions. + + ### Heating and Cooling Guidelines + + To save energy, you should transition from efficiency mode to occupancy mode in the morning as late as possible, but early enough to ensure the building is in setpoints when the occupants arrive. + Depending on the outside air temperature, the building will take some time to get into setpoint ranges, especially in the mornings before transitioning from efficiency mode to occupancy mode. + Therefore, you must apply heating or cooling early enough to ensure that the setpoint temperatures are met before occupancy mode setpoints are applied. + + Time it takes to increase zone air temperature by 1 degree Fahrenheit: + + + Under standard conditions with lower outside air temperature, and active heating, it takes 10 minutes. + + Under standard conditions with higher outside air temperature, and no active cooling, it takes 20 minutes. + + Time it takes to decrease zone air temperature by 1 degree Fahrenheit: + + + Under standard conditions with higher outside air temperature, and active cooling, it takes 10 minutes. + + Under standard conditions with lower outside air temperature, and with no active heating, it takes 20 minutes. + """) + + @property + def hvac_system_guidelines_section(self) -> str: + """A section describing building-specific HVAC system setup and guidelines. + + This section describes the HVAC devices under control, and provides + guidance for controlling them. + """ + + return self.dedent(f""" + ## HVAC System Control Guidelines + + There are two systems under your control, with three devices total. + The Air Handler System (AHS) includes two air handler / air conditioner devices (AC-1 and AC-2). + The Hot Water System (HWS) includes one boiler device (BLR). + + ### Devices and Setpoints + + **AC-1**: Air Conditioner / Air Handler Unit (for all zones on the first floor) + + * '{hybrid_env.DISCRETE_ACTION_COMMAND}': you can turn the device ON (1) and OFF (0) + * '{AHU_STATIC_PRESSURE_SETPOINT}': you can increase/decrease airflow by increasing/decreasing static pressure + * '{AHU_SUPPLY_AIR_TEMPERATURE_SETPOINT}': you can cool the zones by lowering the supply air temperature + + **AC-2**: Air Conditioner / Air Handler Unit (for all zones on the second floor) + + * '{hybrid_env.DISCRETE_ACTION_COMMAND}': you can turn the device ON (1) and OFF (0) + * '{AHU_STATIC_PRESSURE_SETPOINT}': you can increase/decrease airflow by increasing/decreasing static pressure + * '{AHU_SUPPLY_AIR_TEMPERATURE_SETPOINT}': you can cool the zones by lowering the supply air temperature + + **BLR**: Boiler (for both floors): + + * '{hybrid_env.DISCRETE_ACTION_COMMAND}': you can turn the device ON (1) and OFF (0) + * '{HWS_DIFFERENTIAL_PRESSURE_SETPOINT}': you can increase/decrease water flow to the zones by increasing/decreasing differential pressure + * '{HWS_SUPPLY_WATER_TEMPERATURE_SETPOINT}': you can heat the zones by increasing the water supply temperature + + ### Air Conditioner (AC) / Air Handler (AHU) Guidelines + + Turning on an AC will consume electricity by running the air blowers and running the refrigeration compressors. + Turning them off will not consume any electricity, but will also remove air cooling and ventilation. + + Lowering an AC's supply air temperature below outside air temperature will cause the compressor to run, consuming electricity, and will cool the zones. + Setting the supply air temperature only enables you to cool, but not heat the zones. + + Increasing an AC's static pressure will increase air circulation through the zones, which results in cooling or heating the zones. + + ### Boiler (BLR) Guidelines + + Lowering the boiler's supply water temperature will reduce carbon emission, but will also reduce the ability to heat zones. + + ### Zone Temperature Control Guidelines + + If a zone is occupied and the zone air temperature is below the zone air heating temperature setpoint, the VAV in the zone will request air flow and hot water circulation to heat the zone. + You control air flow by managing the AHU static pressure setpoints, and hot water circulation by managing the HWS differential pressure and supply water temperature setpoints. + + If the zone is occupied and the zone air temperature is above the zone air cooling temperature setpoint, the VAV in the zone will request cool air from the AHU. + You control the amount of cooling by managing the AHU static pressure and supply air temperature setpoints. + """) + + @property + def action_guidelines_section(self) -> str: + """A section describing the action space.""" + + return self.dedent(f""" + ## Action Guidelines + + Throughout the day, you will be prompted to choose your actions. + Your actions will be used to control the HVAC systems in the building. + An action requires a value and justification for each of the device setpoints listed below. + + {self.setpoints_df.to_markdown(index=False)} + + Note about temperature units: + All temperatures will be reported to you in Kelvin. + The temperatures you choose to set should be in Kelvin. + However, in your textual responses and justifications only, + you should communicate temperatures in {self.temp_display_unit.value} instead, + accurately converting and translating between units as necessary. + """) + + @property + def current_conditions_section(self) -> str: + """A section describing the current conditions in the building.""" + + # TODO(mjrossetti): Add upcoming temperature forecast for at least the next + # six hours, using interpolation and caching strategies. + + return self.dedent(f""" + ## Current Conditions + + The current local time is: {self.env.current_local_timestamp.strftime('%A, %B %d, %Y %l:%M %p %Z')}. + + The current outside air temperature is: {self.observation_response_parser.outside_air_temp:.1f} Kelvin. + + Total number of zones: {len(self.env.building.zones)} + + Current number of occupants: {self.reward_info_parser.total_occupancy}. + + Current number of occupants exposed to unacceptable comfort conditions: {self.reward_info_parser.num_occupants_uncomfortable}. + + {self.zone_conditions_subsection} + + ### Current Power Consumption + + The table below shows the current energy consumption for each device: + + {self.reward_info_parser.energy_consumption_df_watts.to_markdown(index=False)} + """) + + @property + def zone_conditions_subsection(self) -> str: + """A subsection describing the current zone conditions. + + For floor-by-floor occupant comfort, see the FloorBasedPromptmaker class. + """ + + return self.dedent(f""" + ### Current Zone Temperatures + + The table below conveys the comfort conditions across all zones in the building: + + {self.reward_info_parser.zone_conditions_histogram.to_markdown(index=True)} + + The first two rows show the number of zones and the number of occupants at a specific temperature. + The row marked 'temperature setpoint range' makes a '+' for a temperature inside acceptable range, and a '-' for a temperature outside of acceptable range. + The row labeled 'count of occupants exposed' indicates the count of all occupants being exposed to unacceptable comfort conditions. + """) + + @property + def current_action_section(self) -> str: + """A section containing guidance for choosing the next action.""" + + return self.dedent(f""" + ## Current Action + + First, observe the building conditions (including occupancy levels, outside air temperature, zone air temperatures, energy consumption levels, etc.), and use this information to devise an overall strategy for your next action. + + According to your strategy, decide to turn each device ON (1) or OFF (0), using their discrete '{hybrid_env.DISCRETE_ACTION_COMMAND}' setpoints. + + For each device, also decide on values for that device's continuous setpoints. + NOTE: even if the devices are off, you still need to supply values for these continuous setpoints, however they will not be used, so it is ok to choose a value in the middle of the setpoint range. + + Provide an overall justification explaining your strategy in a sentence or two. + Also provide a justification for each setpoint you chose in a sentence or two. + + Finally, select a validity interval from the following options: {self.validity_intervals}. + The **validity interval** is the number of minutes the setpoints will remain in effect. + Choose long validity times when under steady conditions, and only apply short validity intervals when the building is undergoing high amount of change. + After the validity interval expires, you will be allowed to assign new setpoints. + + IMPORTANT NOTE: you MUST structure your response according to the "Formatting Instructions" below. + """) diff --git a/smart_control/llm/prompts/promptmaker_test.py b/smart_control/llm/prompts/promptmaker_test.py new file mode 100644 index 00000000..99579b23 --- /dev/null +++ b/smart_control/llm/prompts/promptmaker_test.py @@ -0,0 +1,460 @@ +from absl.testing import absltest +from absl.testing import parameterized +import pandas as pd +from smart_buildings.smart_control.environment import conftest as env_conftest +from smart_buildings.smart_control.llm.prompts import promptmaker +from smart_buildings.smart_control.llm.schema import output_schema +from smart_buildings.smart_control.utils.proto_parsers import observation_response_parser +from smart_buildings.smart_control.utils.proto_parsers import reward_info_parser + +WEIGHTS = { + 'energy_cost_weight': 0.3, + 'carbon_emission_weight': 0.2, + 'productivity_weight': 0.5, +} + +WEIGHTS_INCLUDED_CONTENT = ( + 'We have assigned a weight to designate the importance of each objective.' +) + +BUILDING_INFO = { + 'stories': 'two', + 'sqft': 96_000, + 'location': 'Mountain View, California', + 'name': 'SB-1', +} + + +class PromptmakerTest(absltest.TestCase): + """Tests for the Promptmaker class, with weights present but not included.""" + + def setUp(self): + super().setUp() + self.env = env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT + ) + self.env.reward_function.weights = WEIGHTS + self.pm = promptmaker.Promptmaker(env=self.env) + self.expected_promtpmaker_type = 'Promptmaker' + + def test_initialization(self): + self.assertIsInstance(self.pm, promptmaker.Promptmaker) + + def test_attributes(self): + with self.subTest(name='required_attributes'): + self.assertEqual( + self.pm.output_schema_class, + output_schema.SetpointsAction, + ) + self.assertEqual(self.pm.env, self.env) + + with self.subTest(name='configuration_attributes'): + self.assertFalse(self.pm.include_weights) + self.assertEqual(self.pm.occupancy_mode_min, 10) + self.assertEqual(self.pm.temp_display_unit, 'Fahrenheit') + + with self.subTest(name='building_info'): + building_info = self.pm.building_info + self.assertIsInstance(building_info, promptmaker.BuildingInfo) + self.assertEqual(building_info.stories, 'two') + self.assertEqual(building_info.sqft, 96_000) + self.assertEqual(building_info.location, 'Mountain View, California') + + with self.subTest(name='proto_parsers'): + self.assertFalse(self.pm.lazy_init_protos) + self.assertIsInstance( + self.pm.observation_response_parser, + observation_response_parser.ObservationResponseParser, + ) + self.assertIsInstance( + self.pm.reward_info_parser, + reward_info_parser.RewardInfoParser, + ) + + # PROPERTIES + + def test_json_metadata(self): + json_metadata = self.pm.json_metadata + + with self.subTest(name='type'): + self.assertEqual(json_metadata['type'], self.expected_promtpmaker_type) + + with self.subTest(name='include_weights'): + self.assertEqual(json_metadata['include_weights'], False) + + with self.subTest(name='occupancy_mode_min'): + self.assertEqual(json_metadata['occupancy_mode_min'], 10) + + with self.subTest(name='temp_display_unit'): + self.assertEqual(json_metadata['temp_display_unit'], 'Fahrenheit') + + with self.subTest(name='building_info'): + self.assertEqual(json_metadata['building_info'], BUILDING_INFO) + + def test_weights(self): + self.assertEqual( + self.pm.weights, + { + 'energy_cost_weight': 0.3, + 'carbon_emission_weight': 0.2, + 'comfort_weight': 0.5, + }, + ) + + def test_setpoints_df(self): + df = self.pm.setpoints_df + self.assertIsInstance(df, pd.DataFrame) + + expected_records = [ + { + 'device_id': 'air_handler_1', + 'setpoint_name': 'supervisor_run_command', + 'setpoint_type': 'DISCRETE', + 'units': 'On/Off', + 'min_native_value': 0.0, + 'max_native_value': 1.0, + }, + { + 'device_id': 'air_handler_1', + 'setpoint_name': 'supply_air_heating_temperature_setpoint', + 'setpoint_type': 'CONTINUOUS', + 'units': 'Kelvin', + 'min_native_value': 285.0, + 'max_native_value': 295.0, + }, + { + 'device_id': 'air_handler_2', + 'setpoint_name': 'supervisor_run_command', + 'setpoint_type': 'DISCRETE', + 'units': 'On/Off', + 'min_native_value': 0.0, + 'max_native_value': 1.0, + }, + { + 'device_id': 'air_handler_2', + 'setpoint_name': 'supply_air_heating_temperature_setpoint', + 'setpoint_type': 'CONTINUOUS', + 'units': 'Kelvin', + 'min_native_value': 285.0, + 'max_native_value': 295.0, + }, + { + 'device_id': 'boiler_1', + 'setpoint_name': 'supervisor_run_command', + 'setpoint_type': 'DISCRETE', + 'units': 'On/Off', + 'min_native_value': 0.0, + 'max_native_value': 1.0, + }, + { + 'device_id': 'boiler_1', + 'setpoint_name': 'supply_water_setpoint', + 'setpoint_type': 'CONTINUOUS', + 'units': 'Kelvin', + 'min_native_value': 310.0, + 'max_native_value': 350.0, + }, + ] + self.assertListEqual(df.to_dict('records'), expected_records) + + def test_validity_intervals(self): + self.assertEqual( + self.pm.validity_intervals, + [5, 10, 15, 20, 30, 45, 60, 75, 90, 120], + ) + + # PROMPT CONTENT + + def test_prompt(self): + prompt = self.pm.prompt + with self.subTest(name='objectives_section'): + self.assertIn(self.pm.objectives_section, prompt) + + with self.subTest(name='zone_info_section'): + self.assertIn(self.pm.zone_info_section, prompt) + + with self.subTest(name='occupancy_modes_section'): + self.assertIn(self.pm.occupancy_modes_section, prompt) + + with self.subTest(name='hvac_system_guidelines_section'): + self.assertIn(self.pm.hvac_system_guidelines_section, prompt) + + with self.subTest(name='action_guidelines_section'): + self.assertIn(self.pm.action_guidelines_section, prompt) + + with self.subTest(name='current_conditions_section'): + self.assertIn(self.pm.current_conditions_section, prompt) + + with self.subTest(name='current_action_section'): + self.assertIn(self.pm.current_action_section, prompt) + + with self.subTest(name='formatting_instructions_section'): + self.assertIn(self.pm.formatting_instructions_section, prompt) + + def test_objectives_section(self): + section = self.pm.objectives_section + self.assertIn('## Objectives', section) + self.assertIn('### Role', section) + self.assertIn('### Overall Goal', section) + + with self.subTest(name='includes_building_info'): + self.assertIn('**Building Information**', section) + table = self.pm.building_info_series.to_markdown(index=True) + self.assertIn(table, section) + + with self.subTest(name='weights_present_but_not_included'): + self.assertIsNotNone(self.env.reward_function.weights) + self.assertNotIn(WEIGHTS_INCLUDED_CONTENT, section) + + def test_zone_info_section(self): + section = self.pm.zone_info_section + self.assertIn('## Zone Information', section) + self.assertIn('### Zone Comfort', section) + + def test_occupancy_modes_section(self): + section = self.pm.occupancy_modes_section + self.assertIn('## Occupancy Modes', section) + self.assertIn('### Heating and Cooling Guidelines', section) + + with self.subTest(name='uses_occupancy_mode_min'): + self.assertIn( + '**Occupancy mode** is when the building has at least 10 occupants.', + section, + ) + self.assertIn( + '**Efficiency mode** is when the building has fewer than 10' + ' occupants.', + section, + ) + + def test_hvac_system_guidelines_section(self): + section = self.pm.hvac_system_guidelines_section + + with self.subTest(name='contains_section_headers'): + self.assertIn('## HVAC System Control Guidelines', section) + self.assertIn('### Devices and Setpoints', section) + self.assertIn( + '### Air Conditioner (AC) / Air Handler (AHU) Guidelines', + section, + ) + self.assertIn('### Boiler (BLR) Guidelines', section) + self.assertIn('### Zone Temperature Control Guidelines', section) + + with self.subTest(name='mentions_specific_devices'): + self.assertIn( + '**AC-1**: Air Conditioner / Air Handler Unit (for all zones on the' + ' first floor)', + section, + ) + self.assertIn( + '**AC-2**: Air Conditioner / Air Handler Unit (for all zones on the' + ' second floor)', + section, + ) + self.assertIn('**BLR**: Boiler (for both floors)', section) + + with self.subTest(name='mentions_key_setpoints'): + self.assertIn("'supervisor_run_command'", section) + self.assertIn("'static_pressure_setpoint'", section) + self.assertIn("'supply_air_temperature_setpoint'", section) + self.assertIn("'differential_pressure_setpoint'", section) + self.assertIn("'supply_water_setpoint'", section) + + def test_action_guidelines_section(self): + section = self.pm.action_guidelines_section + + with self.subTest(name='contains_section_header'): + self.assertIn('## Action Guidelines', section) + + with self.subTest(name='includes_device_setpoints_table'): + self.assertIn(self.pm.setpoints_df.to_markdown(index=False), section) + + with self.subTest(name='includes_temp_display_unit'): + self.assertIn( + 'you should communicate temperatures in Fahrenheit instead', + section, + ) + + def test_current_conditions_section(self): + section = self.pm.current_conditions_section + + with self.subTest(name='contains_section_headers'): + self.assertIn('## Current Conditions', section) + self.assertIn('### Current Zone Temperatures', section) + self.assertIn('### Current Power Consumption', section) + + with self.subTest(name='includes_current_local_time'): + self.assertIn( + 'The current local time is: Monday, June 07, 2021 12:00 PM PDT', + section, + ) + + with self.subTest(name='includes_current_outside_air_temperature'): + self.assertIn( + 'The current outside air temperature is: 295.0 Kelvin', + section, + ) + + with self.subTest(name='includes_occupant_counts'): + self.assertIn('Total number of zones: 2', section) + self.assertIn( + 'Current number of occupants: 10', + section, + ) + self.assertIn( + 'Current number of occupants exposed to unacceptable comfort' + ' conditions: 0', + section, + ) + + parser = self.pm.reward_info_parser + self.assertIsNotNone(parser) + + # pytype: disable=attribute-error + with self.subTest(name='includes_current_zone_temperatures_table'): + table = parser.zone_conditions_histogram.to_markdown(index=True) + self.assertIn(table, section) + + with self.subTest(name='includes_current_power_consumption_table'): + table = parser.energy_consumption_df_watts.to_markdown(index=False) + self.assertIn(table, section) + # pytype: enable=attribute-error + + def test_current_action_section(self): + section = self.pm.current_action_section + + with self.subTest(name='contains_section_header'): + self.assertIn('## Current Action', section) + + with self.subTest(name='specifies_discrete_action_commands'): + self.assertIn( + 'According to your strategy, decide to turn each device ON (1) or OFF' + " (0), using their discrete 'supervisor_run_command' setpoints.", + section, + ) + + with self.subTest(name='specifies_validity_interval_options'): + self.assertIn( + 'Finally, select a validity interval from the following options:' + ' [5, 10, 15, 20, 30, 45, 60, 75, 90, 120]', + section, + ) + + def test_formatting_instructions_section(self): + section = self.pm.formatting_instructions_section + self.assertIn('## Formatting Instructions', section) + + +class PromptmakerWeightsUnavailableTest(absltest.TestCase): + """Tests for the Promptmaker class, with weights not present.""" + + def setUp(self): + super().setUp() + self.env = env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT + ) + + def test_weights_not_requested_or_present(self): + pm = promptmaker.Promptmaker(env=self.env) + # Weights are not requested: + self.assertFalse(pm.include_weights) + # Weights are not present: + self.assertFalse(hasattr(self.env.reward_function, 'weights')) + + section = pm.objectives_section + self.assertNotIn(WEIGHTS_INCLUDED_CONTENT, section) + + def test_weights_requested_but_not_present(self): + pm = promptmaker.Promptmaker(env=self.env, include_weights=True) + # Weights are requested: + self.assertTrue(pm.include_weights) + # Weights are not present: + self.assertFalse(hasattr(self.env.reward_function, 'weights')) + + section = pm.objectives_section + self.assertNotIn(WEIGHTS_INCLUDED_CONTENT, section) + + +class PromptmakerWeightsInclusionTest(absltest.TestCase): + """Tests for the Promptmaker class, with weights present and included.""" + + def setUp(self): + super().setUp() + self.env = env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT + ) + self.env.reward_function.weights = WEIGHTS + self.pm = promptmaker.Promptmaker(env=self.env, include_weights=True) + + def test_weights(self): + self.assertEqual( + self.pm.weights, + { + 'energy_cost_weight': 0.3, + 'carbon_emission_weight': 0.2, + 'comfort_weight': 0.5, + }, + ) + + def test_weights_included(self): + weights = self.pm.weights + self.assertIsInstance(weights, dict) + + section = self.pm.objectives_section + self.assertIn(WEIGHTS_INCLUDED_CONTENT, section) + weights_table = pd.Series(weights, name='weight').to_markdown(index=True) + self.assertIn(weights_table, section) + + +class PromptmakerLazyInitProtosTest(parameterized.TestCase): + + ATTRIBUTE_NAMES = ( + dict( + testcase_name='base_prompt', + attribute_name='base_prompt', + ), + dict( + testcase_name='current_conditions_section', + attribute_name='current_conditions_section', + ), + ) + + def setUp(self): + super().setUp() + self.env = env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT + ) + self.pm = promptmaker.Promptmaker(self.env, lazy_init_protos=True) + + @parameterized.named_parameters(*ATTRIBUTE_NAMES) + def test_lazy_init_protos_raises_when_protos_not_set(self, attribute_name): + self.assertIsNone(self.pm._observation_response_parser) + self.assertIsNone(self.pm._reward_info_parser) + + with self.assertRaisesRegex( + ValueError, 'Observation response parser is None.' + ): + _ = getattr(self.pm, attribute_name) + + @parameterized.named_parameters(*ATTRIBUTE_NAMES) + def test_lazy_init_protos_ok_when_protos_are_set(self, attribute_name): + self.assertIsNone(self.pm._observation_response_parser) + self.assertIsNone(self.pm._reward_info_parser) + + self.pm.set_protos( + observation_response=self.env.get_observation_response(), + reward_info=self.env.get_reward_info(), + ) + self.assertIsInstance( + self.pm.observation_response_parser, + observation_response_parser.ObservationResponseParser, + ) + self.assertIsInstance( + self.pm.reward_info_parser, + reward_info_parser.RewardInfoParser, + ) + _ = getattr(self.pm, attribute_name) # No error thrown. + + +if __name__ == '__main__': + absltest.main() diff --git a/smart_control/llm/prompts/sb1/examples/example_floor_based_prompt.md b/smart_control/llm/prompts/sb1/examples/example_floor_based_prompt.md new file mode 100644 index 00000000..c815135f --- /dev/null +++ b/smart_control/llm/prompts/sb1/examples/example_floor_based_prompt.md @@ -0,0 +1,228 @@ +# Agent Instructions + +## Objectives + +### Role + +You are a skilled, experienced, and innovative operator of a commercial office building. +You possess in-depth and complete knowledge about HVAC systems, as well as ASHRAE standards and certifications. +Your job is to optimally control HVAC devices in a given commercial office building. + +**Building Information**: + +| | building_info | +|:---------|:--------------------------| +| name | SB-1 | +| stories | two | +| sqft | 96000 | +| location | Mountain View, California | + +### Overall Goal + +As the building operator, your **Optimal Control Objectives** are to: + ++ Minimize energy consumption / costs, and ++ Minimize carbon emissions, and ++ Maintain occupant comfort (a.k.a. productivity) + +This is a multi-objective optimization problem, where you must balance competing objectives. + +### Reward Function Weights + +We have assigned a weight to designate the importance of each objective. +Your job is to maximize the weighted sum of the objectives, placing a higher priority on objectives with greater weights. +The weights are designated in the table below: + +| | weight | +|:-----------------------|---------:| +| energy_cost_weight | 0.2 | +| carbon_emission_weight | 0.2 | +| comfort_weight | 0.6 | + +## Zone Information + +A **zone** is a room, or space in the office building that is potentially occupied by humans, and must be conditioned for comfort when occupied. + +### Zone Comfort + +The **zone air temperature** is the average temperature in a zone and the measure of comfort in the zone. + +The **zone air heating setpoint** is the minimum temperature that zone is allowed to be, without actively heating the zone. +It's like the minimum of the occupant comfort range. +The **zone air cooling setpoint** is the maximum temperature that zone is allowed to be, without actively cooling the zone. +It's like the maximum of the occupant comfort range. +The zone air heating temperature setpoint is always below the zone air cooling temperature setpoint. + +Ideally: `zone air heating setpoint < zone air temperature if occupied < zone air cooling setpoint` + +## Occupancy Modes + +You should operate the building in an occupancy mode and an efficiency mode. + +**Occupancy mode** is when the building has at least 10 occupants. +When in occupancy mode, you should try to maintain zone air temperatures within comfort range (for all occupied zones), while also minimizing energy consumption and carbon emissions. + +**Efficiency mode** is when the building has fewer than 10 occupants. +When in efficiency mode, your only objective should be to SIGNIFICANTLY reduce energy consumption and carbon emissions. + +### Heating and Cooling Guidelines + +To save energy, you should transition from efficiency mode to occupancy mode in the morning as late as possible, but early enough to ensure the building is in setpoints when the occupants arrive. +Depending on the outside air temperature, the building will take some time to get into setpoint ranges, especially in the mornings before transitioning from efficiency mode to occupancy mode. +Therefore, you must apply heating or cooling early enough to ensure that the setpoint temperatures are met before occupancy mode setpoints are applied. + +Time it takes to increase zone air temperature by 1 degree Fahrenheit: + ++ Under standard conditions with lower outside air temperature, and active heating, it takes 10 minutes. ++ Under standard conditions with higher outside air temperature, and no active cooling, it takes 20 minutes. + +Time it takes to decrease zone air temperature by 1 degree Fahrenheit: + ++ Under standard conditions with higher outside air temperature, and active cooling, it takes 10 minutes. ++ Under standard conditions with lower outside air temperature, and with no active heating, it takes 20 minutes. + +## HVAC System Control Guidelines + +There are two systems under your control, with three devices total. +The Air Handler System (AHS) includes two air handler / air conditioner devices (AC-1 and AC-2). +The Hot Water System (HWS) includes one boiler device (BLR). + +### Devices and Setpoints + +**AC-1**: Air Conditioner / Air Handler Unit (for all zones on the first floor) + +* 'supervisor_run_command': you can turn the device ON (1) and OFF (0) +* 'static_pressure_setpoint': you can increase/decrease airflow by increasing/decreasing static pressure +* 'supply_air_temperature_setpoint': you can cool the zones by lowering the supply air temperature + +**AC-2**: Air Conditioner / Air Handler Unit (for all zones on the second floor) + +* 'supervisor_run_command': you can turn the device ON (1) and OFF (0) +* 'static_pressure_setpoint': you can increase/decrease airflow by increasing/decreasing static pressure +* 'supply_air_temperature_setpoint': you can cool the zones by lowering the supply air temperature + +**BLR**: Boiler (for both floors): + +* 'supervisor_run_command': you can turn the device ON (1) and OFF (0) +* 'differential_pressure_setpoint': you can increase/decrease water flow to the zones by increasing/decreasing differential pressure +* 'supply_water_setpoint': you can heat the zones by increasing the water supply temperature + +### Air Conditioner (AC) / Air Handler (AHU) Guidelines + +Turning on an AC will consume electricity by running the air blowers and running the refrigeration compressors. +Turning them off will not consume any electricity, but will also remove air cooling and ventilation. + +Lowering an AC's supply air temperature below outside air temperature will cause the compressor to run, consuming electricity, and will cool the zones. +Setting the supply air temperature only enables you to cool, but not heat the zones. + +Increasing an AC's static pressure will increase air circulation through the zones, which results in cooling or heating the zones. + +### Boiler (BLR) Guidelines + +Lowering the boiler's supply water temperature will reduce carbon emission, but will also reduce the ability to heat zones. + +### Zone Temperature Control Guidelines + +If a zone is occupied and the zone air temperature is below the zone air heating temperature setpoint, the VAV in the zone will request air flow and hot water circulation to heat the zone. +You control air flow by managing the AHU static pressure setpoints, and hot water circulation by managing the HWS differential pressure and supply water temperature setpoints. + +If the zone is occupied and the zone air temperature is above the zone air cooling temperature setpoint, the VAV in the zone will request cool air from the AHU. +You control the amount of cooling by managing the AHU static pressure and supply air temperature setpoints. + +## Action Guidelines + +Throughout the day, you will be prompted to choose your actions. +Your actions will be used to control the HVAC systems in the building. +An action requires a value and justification for each of the device setpoints listed below. + +| device_id | setpoint_name | setpoint_type | units | min_native_value | max_native_value | +|:------------|:--------------------------------------|:----------------|:--------|-------------------:|-------------------:| +| ahs | ahu_1_static_pressure_setpoint | CONTINUOUS | Pascal | 0 | 20000 | +| ahs | ahu_1_supervisor_run_command | DISCRETE | On/Off | 0 | 1 | +| ahs | ahu_1_supply_air_temperature_setpoint | CONTINUOUS | Kelvin | 285 | 305 | +| ahs | ahu_2_static_pressure_setpoint | CONTINUOUS | Pascal | 0 | 20000 | +| ahs | ahu_2_supervisor_run_command | DISCRETE | On/Off | 0 | 1 | +| ahs | ahu_2_supply_air_temperature_setpoint | CONTINUOUS | Kelvin | 285 | 305 | +| hws | differential_pressure | CONTINUOUS | Pascal | 0 | 20 | +| hws | supervisor_run_command | DISCRETE | On/Off | 0 | 1 | +| hws | supply_water_setpoint | CONTINUOUS | Kelvin | 310 | 350 | + +Note about temperature units: +All temperatures will be reported to you in Kelvin. +The temperatures you choose to set should be in Kelvin. +However, in your textual responses and justifications only, +you should communicate temperatures in Fahrenheit instead, +accurately converting and translating between units as necessary. + +## Current Conditions + +The current local time is: Monday, December 16, 2024 12:00 AM PST. + +The current outside air temperature is: 285.1 Kelvin. + +Total number of zones: 126 + +Current number of occupants: 0. + +Current number of occupants exposed to unacceptable comfort conditions: 0. + +### Current Zone Temperatures + +The table below conveys the comfort conditions across all zones in the building, by floor: + +| | 290.0 | 291.0 | 292.0 | 293.0 | 294.0 | 295.0 | 296.0 | 297.0 | 298.0 | 299.0 | 300.0 | +|:----------------|:--------|:--------|:--------|:--------|:--------|:--------|:--------|:--------|:--------|:--------|:--------| +| occupancy_count | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | +| setpoint_mask | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | +| setpoint_range | + | + | + | + | + | + | + | + | + | - | - | +| exposed_count | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | +| occ@floor0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | + +The row 'occupancy_count' shows the total number of occupants building-wide at a specific temperature. +The row 'setpoint_range' indicates with '+' if the temperature is inside the acceptable range, and '-' if it is outside. +The row 'exposed_count' indicates the count of occupants being exposed to unacceptable comfort conditions. +The rows starting with 'occ@floor' show the normalized distribution of zone counts for each floor at that temperature. + +### Current Power Consumption + +The table below shows the current energy consumption for each device: + +| device_type | device_id | metric | description | rate_watts | consumption_kwh | +|:--------------|:------------|:----------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------:|------------------:| +| AHU | ahs | blower_electrical_energy_rate | Cumulative electrical power in W applied to blowers. | 0 | 0 | +| AHU | ahs | air_conditioning_electrical_energy_rate | Cumulative electrical energy rate applied in W for air conditioning. This represents the total power applied for running refrigeration or heat pump cycles (includes running a compressor and pumps to recirculate refrigerant). | 0 | 0 | +| BLR | hws | pump_electrical_energy_rate | Cumulative electrical power in W for water recirculation pumps. | 0 | 0 | +| BLR | hws | natural_gas_heating_energy_rate | Energy rate consumed in W by natural gas for heating water. | 467.875 | 0.0389896 | + +## Current Action + +First, observe the building conditions (including occupancy levels, outside air temperature, zone air temperatures, energy consumption levels, etc.), and use this information to devise an overall strategy for your next action. + +According to your strategy, decide to turn each device ON (1) or OFF (0), using their discrete 'supervisor_run_command' setpoints. + +For each device, also decide on values for that device's continuous setpoints. +NOTE: even if the devices are off, you still need to supply values for these continuous setpoints, however they will not be used, so it is ok to choose a value in the middle of the setpoint range. + +Provide an overall justification explaining your strategy in a sentence or two. +Also provide a justification for each setpoint you chose in a sentence or two. + +Finally, select a validity interval from the following options: [5, 10, 15, 20, 30, 45, 60, 75, 90, 120]. +The **validity interval** is the number of minutes the setpoints will remain in effect. +Choose long validity times when under steady conditions, and only apply short validity intervals when the building is undergoing high amount of change. +After the validity interval expires, you will be allowed to assign new setpoints. + +IMPORTANT NOTE: you MUST structure your response according to the "Formatting Instructions" below. + +## Formatting Instructions + +IMPORTANT: The output MUST be a single, valid JSON object conforming to the schema below. +Do NOT include any other text, explanations, pleasantries, or any other content before or after the JSON object. +The output should be formatted as a JSON instance that conforms to the JSON schema below. + +As an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]} +the object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted. + +Here is the output schema: +``` +{"$defs": {"DeviceSetpoint": {"description": "A single device setpoint.\n\nA device is uniquely identified by a composite key consisting of the device\nidentifier and the setpoint name.\n\nAttributes:\n device_id: The unique identifier of the device (e.g. 'boiler-123-xyz').\n setpoint_name: The name of the setpoint (e.g. 'supply_water_temperature').\n setpoint_value: The requested value to be set (e.g. 120.0).\n justification: The reason for choosing this specific device setting.", "properties": {"device_id": {"description": "The unique identifier of the device.", "title": "Device Id", "type": "string"}, "setpoint_name": {"description": "The name of the setpoint.", "title": "Setpoint Name", "type": "string"}, "setpoint_value": {"description": "The requested value to be set.", "title": "Setpoint Value", "type": "number"}, "justification": {"description": "The reason for choosing this specific device setting.", "title": "Justification", "type": "string"}}, "required": ["device_id", "setpoint_name", "setpoint_value", "justification"], "title": "DeviceSetpoint", "type": "object"}}, "description": "A flexible action model for setting any number of setpoints.\n\nAttributes:\n timestamp: The time the action is taken (in the building's local timezone).\n justification: The overall reason for taking this action. Includes a brief\n description of why the action is justified, as well as the desired\n outcome of the action as a whole.\n setpoints: A list of setpoints.\n validity_interval: The amount of time in minutes the setpoints should remain\n in effect before prompting for a new action.", "properties": {"timestamp": {"description": "The time the action is taken, formatted as 'YYYY-MM-DD HH:MM:SS', assumed to be in the building's local timezone.", "title": "Timestamp", "type": "string"}, "justification": {"description": "The overall reason for taking this action. Includes a brief description of why the action is justified, as well as the desired outcome of the action as a whole.", "title": "Justification", "type": "string"}, "setpoints": {"description": "A list of setpoints.", "items": {"$ref": "#/$defs/DeviceSetpoint"}, "title": "Setpoints", "type": "array"}, "validity_interval": {"description": "The number of minutes the setpoints should remain in effect before prompting for a new action.", "enum": [5, 10, 15, 20, 30, 45, 60, 75, 90, 120], "title": "Validity Interval", "type": "integer"}}, "required": ["timestamp", "justification", "setpoints", "validity_interval"]} +``` diff --git a/smart_control/llm/prompts/sb1/examples/example_prompt.md b/smart_control/llm/prompts/sb1/examples/example_prompt.md new file mode 100644 index 00000000..064b92a2 --- /dev/null +++ b/smart_control/llm/prompts/sb1/examples/example_prompt.md @@ -0,0 +1,226 @@ +# Agent Instructions + +## Objectives + +### Role + +You are a skilled, experienced, and innovative operator of a commercial office building. +You possess in-depth and complete knowledge about HVAC systems, as well as ASHRAE standards and certifications. +Your job is to optimally control HVAC devices in a given commercial office building. + +**Building Information**: + +| | building_info | +|:---------|:--------------------------| +| name | SB-1 | +| stories | two | +| sqft | 96000 | +| location | Mountain View, California | + +### Overall Goal + +As the building operator, your **Optimal Control Objectives** are to: + ++ Minimize energy consumption / costs, and ++ Minimize carbon emissions, and ++ Maintain occupant comfort (a.k.a. productivity) + +This is a multi-objective optimization problem, where you must balance competing objectives. + +### Reward Function Weights + +We have assigned a weight to designate the importance of each objective. +Your job is to maximize the weighted sum of the objectives, placing a higher priority on objectives with greater weights. +The weights are designated in the table below: + +| | weight | +|:-----------------------|---------:| +| energy_cost_weight | 0.2 | +| carbon_emission_weight | 0.2 | +| comfort_weight | 0.6 | + +## Zone Information + +A **zone** is a room, or space in the office building that is potentially occupied by humans, and must be conditioned for comfort when occupied. + +### Zone Comfort + +The **zone air temperature** is the average temperature in a zone and the measure of comfort in the zone. + +The **zone air heating setpoint** is the minimum temperature that zone is allowed to be, without actively heating the zone. +It's like the minimum of the occupant comfort range. +The **zone air cooling setpoint** is the maximum temperature that zone is allowed to be, without actively cooling the zone. +It's like the maximum of the occupant comfort range. +The zone air heating temperature setpoint is always below the zone air cooling temperature setpoint. + +Ideally: `zone air heating setpoint < zone air temperature if occupied < zone air cooling setpoint` + +## Occupancy Modes + +You should operate the building in an occupancy mode and an efficiency mode. + +**Occupancy mode** is when the building has at least 10 occupants. +When in occupancy mode, you should try to maintain zone air temperatures within comfort range (for all occupied zones), while also minimizing energy consumption and carbon emissions. + +**Efficiency mode** is when the building has fewer than 10 occupants. +When in efficiency mode, your only objective should be to SIGNIFICANTLY reduce energy consumption and carbon emissions. + +### Heating and Cooling Guidelines + +To save energy, you should transition from efficiency mode to occupancy mode in the morning as late as possible, but early enough to ensure the building is in setpoints when the occupants arrive. +Depending on the outside air temperature, the building will take some time to get into setpoint ranges, especially in the mornings before transitioning from efficiency mode to occupancy mode. +Therefore, you must apply heating or cooling early enough to ensure that the setpoint temperatures are met before occupancy mode setpoints are applied. + +Time it takes to increase zone air temperature by 1 degree Fahrenheit: + ++ Under standard conditions with lower outside air temperature, and active heating, it takes 10 minutes. ++ Under standard conditions with higher outside air temperature, and no active cooling, it takes 20 minutes. + +Time it takes to decrease zone air temperature by 1 degree Fahrenheit: + ++ Under standard conditions with higher outside air temperature, and active cooling, it takes 10 minutes. ++ Under standard conditions with lower outside air temperature, and with no active heating, it takes 20 minutes. + +## HVAC System Control Guidelines + +There are two systems under your control, with three devices total. +The Air Handler System (AHS) includes two air handler / air conditioner devices (AC-1 and AC-2). +The Hot Water System (HWS) includes one boiler device (BLR). + +### Devices and Setpoints + +**AC-1**: Air Conditioner / Air Handler Unit (for all zones on the first floor) + +* 'supervisor_run_command': you can turn the device ON (1) and OFF (0) +* 'static_pressure_setpoint': you can increase/decrease airflow by increasing/decreasing static pressure +* 'supply_air_temperature_setpoint': you can cool the zones by lowering the supply air temperature + +**AC-2**: Air Conditioner / Air Handler Unit (for all zones on the second floor) + +* 'supervisor_run_command': you can turn the device ON (1) and OFF (0) +* 'static_pressure_setpoint': you can increase/decrease airflow by increasing/decreasing static pressure +* 'supply_air_temperature_setpoint': you can cool the zones by lowering the supply air temperature + +**BLR**: Boiler (for both floors): + +* 'supervisor_run_command': you can turn the device ON (1) and OFF (0) +* 'differential_pressure_setpoint': you can increase/decrease water flow to the zones by increasing/decreasing differential pressure +* 'supply_water_setpoint': you can heat the zones by increasing the water supply temperature + +### Air Conditioner (AC) / Air Handler (AHU) Guidelines + +Turning on an AC will consume electricity by running the air blowers and running the refrigeration compressors. +Turning them off will not consume any electricity, but will also remove air cooling and ventilation. + +Lowering an AC's supply air temperature below outside air temperature will cause the compressor to run, consuming electricity, and will cool the zones. +Setting the supply air temperature only enables you to cool, but not heat the zones. + +Increasing an AC's static pressure will increase air circulation through the zones, which results in cooling or heating the zones. + +### Boiler (BLR) Guidelines + +Lowering the boiler's supply water temperature will reduce carbon emission, but will also reduce the ability to heat zones. + +### Zone Temperature Control Guidelines + +If a zone is occupied and the zone air temperature is below the zone air heating temperature setpoint, the VAV in the zone will request air flow and hot water circulation to heat the zone. +You control air flow by managing the AHU static pressure setpoints, and hot water circulation by managing the HWS differential pressure and supply water temperature setpoints. + +If the zone is occupied and the zone air temperature is above the zone air cooling temperature setpoint, the VAV in the zone will request cool air from the AHU. +You control the amount of cooling by managing the AHU static pressure and supply air temperature setpoints. + +## Action Guidelines + +Throughout the day, you will be prompted to choose your actions. +Your actions will be used to control the HVAC systems in the building. +An action requires a value and justification for each of the device setpoints listed below. + +| device_id | setpoint_name | setpoint_type | units | min_native_value | max_native_value | +|:------------|:--------------------------------------|:----------------|:--------|-------------------:|-------------------:| +| ahs | ahu_1_static_pressure_setpoint | CONTINUOUS | Pascal | 0 | 20000 | +| ahs | ahu_1_supervisor_run_command | DISCRETE | On/Off | 0 | 1 | +| ahs | ahu_1_supply_air_temperature_setpoint | CONTINUOUS | Kelvin | 285 | 305 | +| ahs | ahu_2_static_pressure_setpoint | CONTINUOUS | Pascal | 0 | 20000 | +| ahs | ahu_2_supervisor_run_command | DISCRETE | On/Off | 0 | 1 | +| ahs | ahu_2_supply_air_temperature_setpoint | CONTINUOUS | Kelvin | 285 | 305 | +| hws | differential_pressure | CONTINUOUS | Pascal | 0 | 20 | +| hws | supervisor_run_command | DISCRETE | On/Off | 0 | 1 | +| hws | supply_water_setpoint | CONTINUOUS | Kelvin | 310 | 350 | + +Note about temperature units: +All temperatures will be reported to you in Kelvin. +The temperatures you choose to set should be in Kelvin. +However, in your textual responses and justifications only, +you should communicate temperatures in Fahrenheit instead, +accurately converting and translating between units as necessary. + +## Current Conditions + +The current local time is: Monday, December 16, 2024 12:00 AM PST. + +The current outside air temperature is: 285.1 Kelvin. + +Total number of zones: 126 + +Current number of occupants: 0. + +Current number of occupants exposed to unacceptable comfort conditions: 0. + +### Current Zone Temperatures + +The table below conveys the comfort conditions across all zones in the building: + +| | 290.0K | 291.0K | 292.0K | 293.0K | 294.0K | 295.0K | 296.0K | 297.0K | 298.0K | 299.0K | 300.0K | +|:---------------------------|:---------|:---------|:---------|:---------|:---------|:---------|:---------|:---------|:---------|:---------|:---------| +| count of zones | 0 | 0 | 0 | 0 | 126 | 0 | 0 | 0 | 0 | 0 | 0 | +| count of occupants | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | +| temperature setpoint range | + | + | + | + | + | + | + | + | + | - | - | +| count of occupants exposed | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | + +The first two rows show the number of zones and the number of occupants at a specific temperature. +The row marked 'temperature setpoint range' makes a '+' for a temperature inside acceptable range, and a '-' for a temperature outside of acceptable range. +The row labeled 'count of occupants exposed' indicates the count of all occupants being exposed to unacceptable comfort conditions. + +### Current Power Consumption + +The table below shows the current energy consumption for each device: + +| device_type | device_id | metric | description | rate_watts | consumption_kwh | +|:--------------|:------------|:----------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------:|------------------:| +| AHU | ahs | blower_electrical_energy_rate | Cumulative electrical power in W applied to blowers. | 0 | 0 | +| AHU | ahs | air_conditioning_electrical_energy_rate | Cumulative electrical energy rate applied in W for air conditioning. This represents the total power applied for running refrigeration or heat pump cycles (includes running a compressor and pumps to recirculate refrigerant). | 0 | 0 | +| BLR | hws | pump_electrical_energy_rate | Cumulative electrical power in W for water recirculation pumps. | 0 | 0 | +| BLR | hws | natural_gas_heating_energy_rate | Energy rate consumed in W by natural gas for heating water. | 467.875 | 0.0389896 | + +## Current Action + +First, observe the building conditions (including occupancy levels, outside air temperature, zone air temperatures, energy consumption levels, etc.), and use this information to devise an overall strategy for your next action. + +According to your strategy, decide to turn each device ON (1) or OFF (0), using their discrete 'supervisor_run_command' setpoints. + +For each device, also decide on values for that device's continuous setpoints. +NOTE: even if the devices are off, you still need to supply values for these continuous setpoints, however they will not be used, so it is ok to choose a value in the middle of the setpoint range. + +Provide an overall justification explaining your strategy in a sentence or two. +Also provide a justification for each setpoint you chose in a sentence or two. + +Finally, select a validity interval from the following options: [5, 10, 15, 20, 30, 45, 60, 75, 90, 120]. +The **validity interval** is the number of minutes the setpoints will remain in effect. +Choose long validity times when under steady conditions, and only apply short validity intervals when the building is undergoing high amount of change. +After the validity interval expires, you will be allowed to assign new setpoints. + +IMPORTANT NOTE: you MUST structure your response according to the "Formatting Instructions" below. + +## Formatting Instructions + +IMPORTANT: The output MUST be a single, valid JSON object conforming to the schema below. +Do NOT include any other text, explanations, pleasantries, or any other content before or after the JSON object. +The output should be formatted as a JSON instance that conforms to the JSON schema below. + +As an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]} +the object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted. + +Here is the output schema: +``` +{"$defs": {"DeviceSetpoint": {"description": "A single device setpoint.\n\nA device is uniquely identified by a composite key consisting of the device\nidentifier and the setpoint name.\n\nAttributes:\n device_id: The unique identifier of the device (e.g. 'boiler-123-xyz').\n setpoint_name: The name of the setpoint (e.g. 'supply_water_temperature').\n setpoint_value: The requested value to be set (e.g. 120.0).\n justification: The reason for choosing this specific device setting.", "properties": {"device_id": {"description": "The unique identifier of the device.", "title": "Device Id", "type": "string"}, "setpoint_name": {"description": "The name of the setpoint.", "title": "Setpoint Name", "type": "string"}, "setpoint_value": {"description": "The requested value to be set.", "title": "Setpoint Value", "type": "number"}, "justification": {"description": "The reason for choosing this specific device setting.", "title": "Justification", "type": "string"}}, "required": ["device_id", "setpoint_name", "setpoint_value", "justification"], "title": "DeviceSetpoint", "type": "object"}}, "description": "A flexible action model for setting any number of setpoints.\n\nAttributes:\n timestamp: The time the action is taken (in the building's local timezone).\n justification: The overall reason for taking this action. Includes a brief\n description of why the action is justified, as well as the desired\n outcome of the action as a whole.\n setpoints: A list of setpoints.\n validity_interval: The amount of time in minutes the setpoints should remain\n in effect before prompting for a new action.", "properties": {"timestamp": {"description": "The time the action is taken, formatted as 'YYYY-MM-DD HH:MM:SS', assumed to be in the building's local timezone.", "title": "Timestamp", "type": "string"}, "justification": {"description": "The overall reason for taking this action. Includes a brief description of why the action is justified, as well as the desired outcome of the action as a whole.", "title": "Justification", "type": "string"}, "setpoints": {"description": "A list of setpoints.", "items": {"$ref": "#/$defs/DeviceSetpoint"}, "title": "Setpoints", "type": "array"}, "validity_interval": {"description": "The number of minutes the setpoints should remain in effect before prompting for a new action.", "enum": [5, 10, 15, 20, 30, 45, 60, 75, 90, 120], "title": "Validity Interval", "type": "integer"}}, "required": ["timestamp", "justification", "setpoints", "validity_interval"]} +``` diff --git a/smart_control/llm/prompts/sb1/generator.py b/smart_control/llm/prompts/sb1/generator.py new file mode 100644 index 00000000..6d2b6f49 --- /dev/null +++ b/smart_control/llm/prompts/sb1/generator.py @@ -0,0 +1,50 @@ +"""Example prompt generator for Building 'SB-1'. + +To run this script using blaze: + +```sh +blaze run //third_party/py/smart_buildings/smart_control/llm/prompts/sb1:generator +``` + +Arguments: + + --include_weights: Whether to include weights in the prompt (default: True). + --md_filename: Filename for the markdown file (default: 'example_prompt.md'). +""" # pylint: disable=line-too-long + +import os + +from absl import app +from absl import flags +from smart_buildings.smart_control.configs.resources.sb1.config_utils import full_config +from smart_buildings.smart_control.llm.prompts import generator +from smart_buildings.smart_control.llm.prompts.sb1 import sb1_promptmaker + +INCLUDE_WEIGHTS = flags.DEFINE_boolean( + "include_weights", True, "Include weights in the prompt." +) + + +def main(_) -> None: + """Loads environment, creates prompt, and writes to markdown file.""" + + print("SETTING GIN CONFIG...") + full_config.set_gin_config() + + generator.write_prompt_md( + promptmaker_class=sb1_promptmaker.SB1Promptmaker, + include_weights=INCLUDE_WEIGHTS.value, + dirpath=os.path.dirname(os.path.realpath(__file__)), + filename="example_prompt.md", + ) + + generator.write_prompt_md( + promptmaker_class=sb1_promptmaker.SB1FloorBasedPromptmaker, + include_weights=INCLUDE_WEIGHTS.value, + dirpath=os.path.dirname(os.path.realpath(__file__)), + filename="example_floor_based_prompt.md", + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/smart_control/llm/prompts/sb1/sb1_promptmaker.py b/smart_control/llm/prompts/sb1/sb1_promptmaker.py new file mode 100644 index 00000000..625c00b4 --- /dev/null +++ b/smart_control/llm/prompts/sb1/sb1_promptmaker.py @@ -0,0 +1,16 @@ +"""Promptmaker for Building SB-1. + +This is a building-specific promptmaker used to generate prompts for controlling +Building 'SB-1'. +""" + +from smart_buildings.smart_control.llm.prompts import floor_based_promptmaker as fbpm +from smart_buildings.smart_control.llm.prompts import promptmaker as pm + + +class SB1Promptmaker(pm.Promptmaker): + """Promptmaker for Building 'SB-1'.""" + + +class SB1FloorBasedPromptmaker(fbpm.FloorBasedPromptmaker): + """Floor-based Promptmaker for Building 'SB-1'.""" diff --git a/smart_control/llm/prompts/sb1/sb1_promptmaker_test.py b/smart_control/llm/prompts/sb1/sb1_promptmaker_test.py new file mode 100644 index 00000000..18632e50 --- /dev/null +++ b/smart_control/llm/prompts/sb1/sb1_promptmaker_test.py @@ -0,0 +1,50 @@ +from absl.testing import absltest +from smart_buildings.smart_control.environment import conftest as env_conftest +from smart_buildings.smart_control.llm.prompts import promptmaker_test +from smart_buildings.smart_control.llm.prompts.sb1 import sb1_promptmaker + + +class SB1PromptmakerTest(promptmaker_test.PromptmakerTest): + + def setUp(self): + super().setUp() + self.env = env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT + ) + self.env.reward_function.weights = promptmaker_test.WEIGHTS + self.pm = sb1_promptmaker.SB1Promptmaker(env=self.env) + self.expected_promtpmaker_type = 'SB1Promptmaker' + + def test_initialization(self): + self.assertIsInstance(self.pm, sb1_promptmaker.SB1Promptmaker) + + +class SB1FloorBasedPromptmakerTest(promptmaker_test.PromptmakerTest): + + def setUp(self): + super().setUp() + self.env = env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT + ) + self.env.reward_function.weights = promptmaker_test.WEIGHTS + self.pm = sb1_promptmaker.SB1FloorBasedPromptmaker(env=self.env) + self.expected_promtpmaker_type = 'SB1FloorBasedPromptmaker' + + def test_initialization(self): + self.assertIsInstance(self.pm, sb1_promptmaker.SB1FloorBasedPromptmaker) + + def test_current_conditions_section(self): + section = self.pm.current_conditions_section + self.assertIn('## Current Conditions', section) + self.assertIn('### Current Zone Temperatures', section) + self.assertIn('by floor:', section) + self.assertIn("The row 'occupancy_count'", section) + self.assertIn("The rows starting with 'occ@floor'", section) + + # Check if the table is present + table = self.pm.zone_conditions_histogram_by_floor.to_markdown(index=True) + self.assertIn(table, section) + + +if __name__ == '__main__': + absltest.main() diff --git a/smart_control/llm/schema/action_context.py b/smart_control/llm/schema/action_context.py new file mode 100644 index 00000000..def5561d --- /dev/null +++ b/smart_control/llm/schema/action_context.py @@ -0,0 +1,385 @@ +"""Action Context is an LLM output schema with awareness of the environment. + +**Setpoint Content Validations** + +The Action Context uses its environment to validate the content of the setpoints +in the requested action. If a given setpoint value is outside the valid range +as defined by that setpoint's action normalizer (thus exceeding the guardrails), +an error will be raised if the `clip` option is set to `False`. However, by +default, if the `clip` option is set to `True`, the setpoints will be clipped to +the bounds of the valid setpoint range, and a record of the error will be stored +(instead of being raised). For example, if the valid range for a setpoint is +[10, 20], and the LLM requests a value of 25, with clipping enabled, the value +will be clipped to 20, and a record of the error will be available in the +`guardrails_exceeded` attribute. + +**Action Formatting** + +The Action Context also uses its environment to convert the setpoints into a +format suitable for stepping the environment. The `ActionContext` class should +be used in conjunction with a normal continuous action `Environment`, whereas +the `HybridActionContext` class should be used with a `HybridActionEnvironment`. +Regardless of which class is used, the `get_action` method produces a properly +formatted action that can be used to step the environment. +""" + +import abc +from collections.abc import Collection +from collections.abc import Sequence +import dataclasses +import json +from typing import Any, Literal, Self + +import pandas as pd +import pydantic +from smart_buildings.smart_control.environment import environment +from smart_buildings.smart_control.environment import hybrid_action_environment +from smart_buildings.smart_control.llm.schema import output_schema + +SteppableActionType = ( + environment.NormalizedActionValues | hybrid_action_environment.HybridAction +) + + +# +# ERRORS +# + + +class GuardrailsExceededError(ValueError): + """Requested setpoint value is outside the normalizer range.""" + + +@dataclasses.dataclass(frozen=True) +class GuardrailsExceededRecord: + """Information about a requested setpoint value that is out of range. + + Attributes: + device_id: The device identifer. + setpoint_name: The name of the setpoint for the given device. + requested_value: The requested setpoint value. + setpoint_range: The valid range of setpoint values accepted by the + environment. + clipped_value: The setpoint value after being clipped to the valid range. + """ + + device_id: str + setpoint_name: str + requested_value: float + setpoint_range: tuple[float, float] + clipped_value: float + + +# +# SCHEMA +# + + +class Steppable(abc.ABC): + """An action schema that produces an action that can step an environment.""" + + @abc.abstractmethod + def get_action( + self, + ) -> SteppableActionType: + """Returns an action used to step the environment.""" + + +class ActionContext(output_schema.SetpointsAction, Steppable): + """A `SetpointsAction` with awareness of the environment. + + This `ActionContext` class should be used in conjunction with a normal + continuous action `Environment`. + """ + + # We are using the environment for validation of the setpoints, but it is not + # part of the Pydantic model schema itself. Because validation runs during + # the parent class initialization, the environment must be assigned + # beforehand, so we use an object.__setattr__() approach. However, Pydantic + # v2's __getattr__ intercepts access to the environment during validation, + # which causes an AttributeError. Defining __slots__ forces the environment + # to be managed via Python's slot mechanism, bypassing Pydantic's + # __getattr__ and allowing it to be accessed during validation. + # + # TODO: b/496194630 - It might make more sense to make this a separate class, + # instead of inheriting from the schema class. + # + __slots__ = ("_env", "_clip", "_guardrails_exceeded") + _env: environment.Environment + _clip: bool + _guardrails_exceeded: list[GuardrailsExceededRecord] + + def __init__( + self, env: environment.Environment, *, clip: bool = True, **kwargs + ): + """Initializes the instance. + + Args: + env: The environment to use for validation. + clip: Governs the behavior when an agent requests a setpoint value that is + outside of the valid range. If `True`, clips the setpoint values to the + bounds of the valid range, and logs a record of the error, but does not + halt execution. This is the default behavior. Otherwise, if `False`, + will raise a `GuardrailsExceededError` and halt execution. + **kwargs: Arguments to pass to initialize the `SetpointsAction` schema. + + Raises: + GuardrailsExceededError: If `clip` is `False` and any setpoint value is + outside the valid range defined by the environment's normalizers. + """ + object.__setattr__(self, "_env", env) + object.__setattr__(self, "_clip", clip) + object.__setattr__(self, "_guardrails_exceeded", []) + super().__init__(**kwargs) + + @classmethod + def from_json( + cls, txt: str, env: environment.Environment, *, clip: bool = True + ) -> Self: + """Creates an instance from a JSON string, while passing extra attributes. + + The LLM responds with a JSON-formatted string, but we need to pass the + environment and clip attributes to the class constructor as well. So this + method solves that problem. + + This method is meant to act as a replacement for Pydantic's + `model_validate_json` method, which we would normally use, but cannot use + with this class due to its custom `__init__` signature. + + Args: + txt: The JSON-formatted string to parse and validate. + env: The environment to use for validation. + clip: Governs the behavior when an agent requests a setpoint value that is + outside of the valid range. If `True`, clips the setpoint values to the + bounds of the valid range, and logs a record of the error, but does not + halt execution. This is the default behavior. Otherwise, if `False`, + will raise a `GuardrailsExceededError` and halt execution. + + Returns: + An instance of the class. + """ + return cls(env=env, clip=clip, **json.loads(txt)) + + @property + def env(self) -> environment.Environment: + """The environment.""" + return self._env + + @property + def clip(self) -> bool: + """Whether to clip setpoint values to the bounds of the valid range.""" + return self._clip + + @property + def guardrails_exceeded(self) -> Collection[GuardrailsExceededRecord]: + """A list of guardrails errors that occurred during validation.""" + return self._guardrails_exceeded + + @pydantic.model_validator(mode="after") + def validate_setpoint_contents(self) -> Self: + """Ensures all env action names are present, and values are in range.""" + setpoint_action_names = set() + + # CHECK SETPOINTS THAT ARE PRESENT IN THE SCHEMA + for setpoint in self.setpoints: + device_id = setpoint.device_id + setpoint_name = setpoint.setpoint_name + try: + action_name = self.env.id_map[(device_id, setpoint_name)] + except KeyError as err: + raise ValueError( + f"Setpoint for ({device_id!r}, {setpoint_name!r}) not found in the" + " environment" + ) from err + setpoint_action_names.add(action_name) + + normalizer = self.env.action_normalizers.get(setpoint_name) + if normalizer is None: + raise ValueError(f"Normalizer not found for setpoint: {action_name!r}") + + setpoint_value = setpoint.setpoint_value + setpoint_min = normalizer.setpoint_min # min native value + setpoint_max = normalizer.setpoint_max # max native value + if not (setpoint_min <= setpoint_value <= setpoint_max): + if self._clip: + clipped_value = max(setpoint_min, min(setpoint_value, setpoint_max)) + self._guardrails_exceeded.append( + GuardrailsExceededRecord( + device_id=device_id, + setpoint_name=setpoint_name, + requested_value=setpoint_value, + setpoint_range=(setpoint_min, setpoint_max), + clipped_value=clipped_value, + ) + ) + setpoint.setpoint_value = clipped_value + else: + raise GuardrailsExceededError( + f"Value {setpoint_value} for setpoint ({device_id!r}," + f" {setpoint_name!r}) is outside expected range [{setpoint_min}," + f" {setpoint_max}]" + ) + + missing_action_names = set(self.env.action_names) - setpoint_action_names + if missing_action_names: + raise ValueError( + "The following setpoints are expected by the environment but are" + f" missing from the schema: {missing_action_names}" + ) + + return self + + @property + def sorted_setpoints(self) -> Sequence[output_schema.DeviceSetpoint]: + """The setpoints, in the same order as the environment's action names.""" + return sorted( + self.setpoints, + key=lambda sp: self.env.action_names.index( + self.env.id_map[(sp.device_id, sp.setpoint_name)] + ), + ) + + def get_action_values(self) -> environment.NormalizedActionValues: + """Returns the normalized values used to step the `Environment`. + + Returns: + A list of normalized action values, sorted in the same order as the + environment's action names. + """ + normalized_values = [] + for sp in self.sorted_setpoints: + action_name = self.env.id_map[(sp.device_id, sp.setpoint_name)] + normalizer = self.env.action_normalizers.get(sp.setpoint_name) + if normalizer is None: + raise ValueError(f"No normalizer found for setpoint: {action_name!r}.") + normalized_values.append(normalizer.agent_value(sp.setpoint_value)) + return normalized_values + + def get_action(self) -> SteppableActionType: + """Returns the action used to step the environment.""" + return self.get_action_values() + + @property + def setpoint_records(self) -> list[dict[str, Any]]: + """The setpoints as a list of records (dictionaries).""" + return [ + { + "timestamp": self.timestamp, + "validity_interval": self.validity_interval, + "justification": self.justification, + "action_name": self.env.id_map[(sp.device_id, sp.setpoint_name)], + "device_id": sp.device_id, + "setpoint_name": sp.setpoint_name, + "setpoint_value": sp.setpoint_value, + "setpoint_justification": sp.justification, + } + for sp in self.sorted_setpoints + ] + + @property + def setpoints_df(self) -> pd.DataFrame: + """The setpoints as a pandas DataFrame.""" + return pd.DataFrame(self.setpoint_records) + + @property + def flattened_setpoints_record(self) -> dict[str, Any]: + """A flattened dictionary of setpoint records. No nesting. + + The dictionary has keys for each action_name and setpoint value, + and a second set of keys for each action_name and setpoint justification. + """ + record = { + "timestamp": self.timestamp, + "validity_interval": self.validity_interval, + "justification": self.justification, + } + for sp in self.sorted_setpoints: + action_name = self.env.id_map[(sp.device_id, sp.setpoint_name)] + record[action_name] = sp.setpoint_value + record[f"{action_name}_justification"] = sp.justification + return record + + +class HybridActionContext(ActionContext, Steppable): + """A `SetpointsAction` with awareness of the environment. + + This class should be used in conjunction with a `HybridActionEnvironment`. + """ + + _env: hybrid_action_environment.HybridActionEnvironment + + def __init__( + self, + env: hybrid_action_environment.HybridActionEnvironment, + *, + clip: bool = True, + **kwargs, + ): + """Initializes the instance. + + Args: + env: The hybrid action environment to use for validation. + clip: Governs the behavior when an agent requests a setpoint value that is + outside of the valid range. If `True`, clips the setpoint values to the + bounds of the valid range, and logs a record of the error, but does not + halt execution. This is the default behavior. Otherwise, if `False`, + will raise a `GuardrailsExceededError` and halt execution. + **kwargs: Arguments to pass to initialize the `SetpointsAction` schema. + + Raises: + GuardrailsExceededError: If `clip` is `False` and any setpoint value is + outside the valid range defined by the environment's normalizers. + """ + super().__init__(env=env, clip=clip, **kwargs) + + def get_hybrid_action(self) -> hybrid_action_environment.HybridAction: + """Returns the hybrid action used to step a `HybridActionEnvironment`.""" + return self.env.convert_to_hybrid(self.get_action_values()) + + def get_action(self) -> SteppableActionType: + """Returns the action used to step the environment.""" + return self.get_hybrid_action() + + +def create_action_context_model( + custom_intervals: Sequence[int], + *, + hybrid: bool = True, +) -> type[ActionContext]: + """Creates an action context model class, using custom validity intervals. + + Args: + custom_intervals: A list of intervals in minutes. Represents the range of + possible options the LLM has to choose from. + hybrid: Whether to create a hybrid action context model class. + + Returns: + A Pydantic model class based on `ActionContext`, but defined using the + provided set of custom validity intervals. + """ + custom_intervals = sorted(set(custom_intervals)) + ValidityIntervalOptions = Literal[*custom_intervals] # pytype: disable=invalid-annotation # pydantic needs it this way + + fields = { + "validity_interval": ( + ValidityIntervalOptions, + pydantic.Field( + description=output_schema.VALIDITY_INTERVAL_DESCRIPTION + ), + ) + } + + if hybrid: + base_class = HybridActionContext + model_name = "HybridActionContextWithCustomInterval" + else: + base_class = ActionContext + model_name = "ActionContextWithCustomInterval" + + model = pydantic.create_model( + model_name, + **fields, + __base__=base_class, + ) + model.__doc__ = base_class.__doc__ + return model diff --git a/smart_control/llm/schema/action_context_test.py b/smart_control/llm/schema/action_context_test.py new file mode 100644 index 00000000..a62621d6 --- /dev/null +++ b/smart_control/llm/schema/action_context_test.py @@ -0,0 +1,576 @@ +from typing import get_args +from unittest import mock + +from absl.testing import absltest +import pandas as pd +import pydantic +from smart_buildings.smart_control.environment import conftest as env_conftest +from smart_buildings.smart_control.environment import environment +from smart_buildings.smart_control.environment import hybrid_action_environment +from smart_buildings.smart_control.llm.schema import action_context +from smart_buildings.smart_control.llm.schema import conftest as schema_conftest +from smart_buildings.smart_control.llm.schema import output_schema_test + + +class ActionContextTest(output_schema_test.ActionTest): + + def setUp(self): + super().setUp() + self.env = env_conftest.create_environment(layout=env_conftest.DEMO_LAYOUT) + self.action_ctx = schema_conftest.create_action_context(env=self.env) + + def test_initialization(self): + self.assertIsInstance(self.action_ctx, action_context.ActionContext) + + def test_env(self): + self.assertIsInstance(self.action_ctx.env, environment.Environment) + + def test_clip(self): + self.assertTrue(self.action_ctx.clip) + + def test_guardrails_exceeded(self): + self.assertEmpty(self.action_ctx.guardrails_exceeded) + + def test_sorted_setpoints(self): + names_from_setpoints = [ + (sp.device_id, sp.setpoint_name) + for sp in self.action_ctx.sorted_setpoints + ] + names_from_env = [ + self.env.id_map.inv[action_name] + for action_name in self.env.action_names + ] + self.assertEqual(names_from_setpoints, names_from_env) + + def test_get_action_values(self): + self.assertEqual(self.action_ctx.get_action_values(), [-1.0, -1.0, -1.0]) + + def test_get_action_values_normalizer_not_found_raises(self): + with mock.patch.dict(self.env.action_normalizers, {}, clear=True): + with self.assertRaisesRegex( + ValueError, + "No normalizer found for setpoint:" + " 'air_handler_1_supply_air_heating_temperature_setpoint'.", + ): + self.action_ctx.get_action_values() + + def test_get_action_values_device_id_not_found_raises(self): + self.action_ctx.setpoints[0].device_id = "OOPS" + with self.assertRaisesRegex( + KeyError, "\\('OOPS', 'supply_air_heating_temperature_setpoint'\\)" + ): + self.action_ctx.get_action_values() + + def test_get_action_values_setpoint_name_not_found_raises(self): + self.action_ctx.setpoints[0].setpoint_name = "OOPS" + with self.assertRaisesRegex(KeyError, "\\('air_handler_1', 'OOPS'\\)"): + self.action_ctx.get_action_values() + + def test_setpoints_df(self): + df = self.action_ctx.setpoints_df + expected_df = pd.DataFrame([ + { + "timestamp": "2025-01-01 12:00:00", + "validity_interval": 60, + "justification": "These are my overall goals.", + "action_name": ( + "air_handler_1_supply_air_heating_temperature_setpoint" + ), + "device_id": "air_handler_1", + "setpoint_name": "supply_air_heating_temperature_setpoint", + "setpoint_value": 285.0, + "setpoint_justification": "To cool the air.", + }, + { + "timestamp": "2025-01-01 12:00:00", + "validity_interval": 60, + "justification": "These are my overall goals.", + "action_name": "boiler_1_supply_water_setpoint", + "device_id": "boiler_1", + "setpoint_name": "supply_water_setpoint", + "setpoint_value": 310.0, + "setpoint_justification": "To heat the water.", + }, + { + "timestamp": "2025-01-01 12:00:00", + "validity_interval": 60, + "justification": "These are my overall goals.", + "action_name": ( + "air_handler_2_supply_air_heating_temperature_setpoint" + ), + "device_id": "air_handler_2", + "setpoint_name": "supply_air_heating_temperature_setpoint", + "setpoint_value": 285.0, + "setpoint_justification": "To cool the air.", + }, + ]) + pd.testing.assert_frame_equal(df, expected_df) + + def test_flattened_setpoints_record(self): + record = self.action_ctx.flattened_setpoints_record + expected_record = { + "timestamp": "2025-01-01 12:00:00", + "validity_interval": 60, + "justification": "These are my overall goals.", + "air_handler_1_supply_air_heating_temperature_setpoint": 285.0, + "air_handler_1_supply_air_heating_temperature_setpoint_justification": ( + "To cool the air." + ), + "boiler_1_supply_water_setpoint": 310.0, + "boiler_1_supply_water_setpoint_justification": "To heat the water.", + "air_handler_2_supply_air_heating_temperature_setpoint": 285.0, + "air_handler_2_supply_air_heating_temperature_setpoint_justification": ( + "To cool the air." + ), + } + self.assertDictEqual(record, expected_record) + + +class HybridActionContextTest(output_schema_test.HybridActionTest): + + def setUp(self): + super().setUp() + self.env = env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT + ) + self.action_ctx = schema_conftest.create_hybrid_action_context(env=self.env) + + def test_initialization(self): + self.assertIsInstance(self.action_ctx, action_context.HybridActionContext) + + def test_env(self): + self.assertIsInstance( + self.action_ctx.env, + hybrid_action_environment.HybridActionEnvironment, + ) + + def test_clip(self): + self.assertTrue(self.action_ctx.clip) + + def test_guardrails_exceeded(self): + self.assertEmpty(self.action_ctx.guardrails_exceeded) + + def test_get_action_values(self): + self.assertEqual( + self.action_ctx.get_action_values(), [-1.0, 1.0, -1.0, 1.0, -1.0, 1.0] + ) + + def test_get_action_values_normalizer_not_found_raises(self): + with mock.patch.dict(self.env.action_normalizers, {}, clear=True): + with self.assertRaisesRegex( + ValueError, + "No normalizer found for setpoint:" + " 'air_handler_1_supply_air_heating_temperature_setpoint'.", + ): + self.action_ctx.get_action_values() + + def test_get_action_values_device_id_not_found_raises(self): + self.action_ctx.setpoints[0].device_id = "OOPS" + with self.assertRaisesRegex( + KeyError, "\\('OOPS', 'supervisor_run_command'\\)" + ): + self.action_ctx.get_action_values() + + def test_get_action_values_setpoint_name_not_found_raises(self): + self.action_ctx.setpoints[0].setpoint_name = "OOPS" + with self.assertRaisesRegex(KeyError, "\\('air_handler_1', 'OOPS'\\)"): + self.action_ctx.get_action_values() + + def test_get_hybrid_action(self): + self.assertEqual( + self.action_ctx.get_hybrid_action(), + { + "continuous_action": [-1.0, -1.0, -1.0], + "discrete_action": [1.0, 1.0, 1.0], + }, + ) + + def test_setpoints_df(self): + df = self.action_ctx.setpoints_df + expected_df = pd.DataFrame([ + { + "timestamp": "2025-01-01 12:00:00", + "validity_interval": 60, + "justification": "These are my overall goals.", + "action_name": ( + "air_handler_1_supply_air_heating_temperature_setpoint" + ), + "device_id": "air_handler_1", + "setpoint_name": "supply_air_heating_temperature_setpoint", + "setpoint_value": 285.0, + "setpoint_justification": "To cool the air.", + }, + { + "timestamp": "2025-01-01 12:00:00", + "validity_interval": 60, + "justification": "These are my overall goals.", + "action_name": "air_handler_1_supervisor_run_command", + "device_id": "air_handler_1", + "setpoint_name": "supervisor_run_command", + "setpoint_value": 1.0, + "setpoint_justification": "To turn the device on.", + }, + { + "timestamp": "2025-01-01 12:00:00", + "validity_interval": 60, + "justification": "These are my overall goals.", + "action_name": "boiler_1_supply_water_setpoint", + "device_id": "boiler_1", + "setpoint_name": "supply_water_setpoint", + "setpoint_value": 310.0, + "setpoint_justification": "To heat the water.", + }, + { + "timestamp": "2025-01-01 12:00:00", + "validity_interval": 60, + "justification": "These are my overall goals.", + "action_name": "boiler_1_supervisor_run_command", + "device_id": "boiler_1", + "setpoint_name": "supervisor_run_command", + "setpoint_value": 1.0, + "setpoint_justification": "To turn the device on.", + }, + { + "timestamp": "2025-01-01 12:00:00", + "validity_interval": 60, + "justification": "These are my overall goals.", + "action_name": ( + "air_handler_2_supply_air_heating_temperature_setpoint" + ), + "device_id": "air_handler_2", + "setpoint_name": "supply_air_heating_temperature_setpoint", + "setpoint_value": 285.0, + "setpoint_justification": "To cool the air.", + }, + { + "timestamp": "2025-01-01 12:00:00", + "validity_interval": 60, + "justification": "These are my overall goals.", + "action_name": "air_handler_2_supervisor_run_command", + "device_id": "air_handler_2", + "setpoint_name": "supervisor_run_command", + "setpoint_value": 1.0, + "setpoint_justification": "To turn the device on.", + }, + ]) + pd.testing.assert_frame_equal(df, expected_df) + + def test_flattened_setpoints_record(self): + record = self.action_ctx.flattened_setpoints_record + expected_record = { + "timestamp": "2025-01-01 12:00:00", + "validity_interval": 60, + "justification": "These are my overall goals.", + "air_handler_1_supervisor_run_command": 1.0, + "air_handler_1_supervisor_run_command_justification": ( + "To turn the device on." + ), + "air_handler_2_supervisor_run_command": 1.0, + "air_handler_2_supervisor_run_command_justification": ( + "To turn the device on." + ), + "boiler_1_supervisor_run_command": 1.0, + "boiler_1_supervisor_run_command_justification": ( + "To turn the device on." + ), + "air_handler_1_supply_air_heating_temperature_setpoint": 285.0, + "air_handler_1_supply_air_heating_temperature_setpoint_justification": ( + "To cool the air." + ), + "air_handler_2_supply_air_heating_temperature_setpoint": 285.0, + "air_handler_2_supply_air_heating_temperature_setpoint_justification": ( + "To cool the air." + ), + "boiler_1_supply_water_setpoint": 310.0, + "boiler_1_supply_water_setpoint_justification": "To heat the water.", + } + self.assertDictEqual(record, expected_record) + + +# +# CONSTRUCTOR TESTS +# + + +class ActionContextFromJsonTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.env = env_conftest.create_environment(layout=env_conftest.DEMO_LAYOUT) + self.schema = schema_conftest.create_action() + self.json = self.schema.model_dump_json() + self.action_ctx = action_context.ActionContext.from_json( + txt=self.json, env=self.env, clip=True + ) + + def test_json(self): + self.assertIsInstance(self.json, str) + + def test_initialization(self): + self.assertIsInstance(self.action_ctx, action_context.ActionContext) + + def test_extra_attributes(self): + self.assertIs(self.action_ctx.env, self.env) + self.assertTrue(self.action_ctx.clip) + self.assertEmpty(self.action_ctx.guardrails_exceeded) + + def test_schema_contents(self): + with self.subTest("timestamp"): + self.assertEqual(self.action_ctx.timestamp, self.schema.timestamp) + + with self.subTest("justification"): + self.assertEqual(self.action_ctx.justification, self.schema.justification) + + with self.subTest("validity_interval"): + self.assertEqual( + self.action_ctx.validity_interval, self.schema.validity_interval + ) + + with self.subTest("setpoints"): + self.assertEqual(self.action_ctx.setpoints, self.schema.setpoints) + + +class HybridActionContextFromJsonTest(ActionContextFromJsonTest): + + def setUp(self): + super().setUp() + self.env = env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT + ) + self.schema = schema_conftest.create_hybrid_action() + self.json = self.schema.model_dump_json() + self.action_ctx = action_context.HybridActionContext.from_json( + txt=self.json, env=self.env, clip=True + ) + + def test_initialization(self): + self.assertIsInstance(self.action_ctx, action_context.HybridActionContext) + + +# +# GUARDRAILS / VALIDATION TESTS +# + + +class ActionContextGuardrailsTest(absltest.TestCase): + """Tests for guardrails behavior when clipping is disabled.""" + + CLIPPING_ENABLED = False + + def setUp(self): + super().setUp() + self.env = env_conftest.create_environment(layout=env_conftest.DEMO_LAYOUT) + self.schema = schema_conftest.create_action() + self.clip = self.CLIPPING_ENABLED + + def test_device_id_not_found_raises(self): + self.schema.setpoints[0].device_id = "OOPS" + with self.assertRaisesRegex( + pydantic.ValidationError, + "Setpoint for \\('OOPS', 'supply_air_heating_temperature_setpoint'\\)" + " not found in the environment", + ): + action_context.ActionContext( + env=self.env, clip=self.clip, **self.schema.model_dump() + ) + + def test_setpoint_name_not_found_raises(self): + self.schema.setpoints[0].setpoint_name = "OOPS" + with self.assertRaisesRegex( + pydantic.ValidationError, + "Setpoint for \\('air_handler_1', 'OOPS'\\) not found in the" + " environment", + ): + action_context.ActionContext( + env=self.env, clip=self.clip, **self.schema.model_dump() + ) + + def test_normalizer_not_found_raises(self): + with mock.patch.dict(self.env.action_normalizers, {}, clear=True): + with self.assertRaisesRegex( + pydantic.ValidationError, + "Normalizer not found for setpoint:" + " 'air_handler_1_supply_air_heating_temperature_setpoint'.", + ): + action_context.ActionContext( + env=self.env, clip=self.clip, **self.schema.model_dump() + ) + + def test_missing_setpoint_raises(self): + self.schema.setpoints.pop() + with self.assertRaisesRegex( + pydantic.ValidationError, + "The following setpoints are expected by the environment but are" + " missing from the schema:.*'boiler_1_supply_water_setpoint'", + ): + action_context.ActionContext( + env=self.env, clip=self.clip, **self.schema.model_dump() + ) + + # TESTS WHERE CLIPPING OPTION IS RELEVANT + + def test_clipping_option(self): + self.assertFalse(self.clip) + + def test_setpoint_value_above_range(self): + self.schema.setpoints[0].setpoint_value = 300.0 # Above range + with self.assertRaisesRegex( + pydantic.ValidationError, + " Value 300.0 for setpoint \\('air_handler_1'.*" + "'supply_air_heating_temperature_setpoint'\\) is outside expected" + " range \\[285\\.0, 295\\.0\\]", + ): + action_context.ActionContext( + env=self.env, clip=self.clip, **self.schema.model_dump() + ) + + def test_setpoint_value_below_range(self): + self.schema.setpoints[0].setpoint_value = 200.0 # Below range + with self.assertRaisesRegex( + pydantic.ValidationError, + " Value 200.0 for setpoint \\('air_handler_1'.*" + "'supply_air_heating_temperature_setpoint'\\) is outside expected" + " range \\[285\\.0, 295\\.0\\]", + ): + action_context.ActionContext( + env=self.env, clip=self.clip, **self.schema.model_dump() + ) + + +class ActionContextGuardrailsClippingTest(ActionContextGuardrailsTest): + """Tests for guardrails behavior when clipping is enabled.""" + + CLIPPING_ENABLED = True + + def test_clipping_option(self): + self.assertTrue(self.clip) + + def test_setpoint_value_above_range(self): + self.schema.setpoints[0].setpoint_value = 300.0 # Above range + action_ctx = action_context.ActionContext( + env=self.env, clip=self.clip, **self.schema.model_dump() + ) + + with self.subTest(name="clips_value_to_max"): + self.assertEqual(action_ctx.setpoints[0].setpoint_value, 295.0) # Max + + with self.subTest(name="logs_guardrails_error"): + self.assertLen(action_ctx.guardrails_exceeded, 1) + self.assertEqual( + action_ctx.guardrails_exceeded[0], + action_context.GuardrailsExceededRecord( + device_id="air_handler_1", + setpoint_name="supply_air_heating_temperature_setpoint", + requested_value=300.0, + setpoint_range=(285.0, 295.0), + clipped_value=295.0, + ), + ) + + def test_setpoint_value_below_range(self): + self.schema.setpoints[0].setpoint_value = 200.0 # Below range + action_ctx = action_context.ActionContext( + env=self.env, clip=self.clip, **self.schema.model_dump() + ) + + with self.subTest(name="clips_value_to_min"): + self.assertEqual(action_ctx.setpoints[0].setpoint_value, 285.0) # Min + + with self.subTest(name="logs_guardrails_error"): + self.assertLen(action_ctx.guardrails_exceeded, 1) + self.assertEqual( + action_ctx.guardrails_exceeded[0], + action_context.GuardrailsExceededRecord( + device_id="air_handler_1", + setpoint_name="supply_air_heating_temperature_setpoint", + requested_value=200.0, + setpoint_range=(285.0, 295.0), + clipped_value=285.0, + ), + ) + + +# +# CUSTOM VALIDITY INTERVALS +# + + +class ActionContextWithCustomValidityIntervalsTest(absltest.TestCase): + + IS_HYBRID = False + + def setUp(self): + super().setUp() + self.custom_intervals = [15, 30, 45, 60] + self.schema = action_context.create_action_context_model( + custom_intervals=self.custom_intervals, + hybrid=self.IS_HYBRID, + ) + + def test_initialization(self): + self.assertTrue(issubclass(self.schema, action_context.ActionContext)) + self.assertFalse( + issubclass(self.schema, action_context.HybridActionContext) + ) + + def test_validity_interval_options(self): + self.assertCountEqual( + get_args(self.schema.__annotations__["validity_interval"]), + self.custom_intervals, + ) + + +class HybridActionContextWithCustomValidityIntervalsTest( + ActionContextWithCustomValidityIntervalsTest +): + + IS_HYBRID = True + + def test_initialization(self): + self.assertTrue(issubclass(self.schema, action_context.ActionContext)) + self.assertTrue(issubclass(self.schema, action_context.HybridActionContext)) + + +# +# FACTORY FUNCTION TESTS +# + + +class ActionContextFactoryTest(absltest.TestCase): + + def test_defaults(self): + action_ctx = schema_conftest.create_action_context() + self.assertIsInstance(action_ctx, action_context.ActionContext) + + def test_overrides(self): + env = env_conftest.create_environment(layout=env_conftest.DEMO_LAYOUT) + action = schema_conftest.create_action() + action.justification = "Custom justification." + + action_ctx = schema_conftest.create_action_context(env=env, action=action) + self.assertIsInstance(action_ctx, action_context.ActionContext) + self.assertEqual(action_ctx.justification, "Custom justification.") + + +class HybridActionContextFactoryTest(ActionContextFactoryTest): + + def test_defaults(self): + action_ctx = schema_conftest.create_hybrid_action_context() + self.assertIsInstance(action_ctx, action_context.HybridActionContext) + + def test_overrides(self): + env = env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT + ) + action = schema_conftest.create_hybrid_action() + action.justification = "Custom justification." + + action_ctx = schema_conftest.create_hybrid_action_context( + env=env, action=action + ) + self.assertIsInstance(action_ctx, action_context.HybridActionContext) + self.assertEqual(action_ctx.justification, "Custom justification.") + + +if __name__ == "__main__": + absltest.main() diff --git a/smart_control/llm/schema/conftest.py b/smart_control/llm/schema/conftest.py new file mode 100644 index 00000000..09034555 --- /dev/null +++ b/smart_control/llm/schema/conftest.py @@ -0,0 +1,386 @@ +"""Test helpers for LLM prompts and output schema models. + +Contains objects for representing the LLM's response in string format, as well +as the corresponding Pydantic models parsed from those strings. + +Provides actions for both the continuous and hybrid action environments. +""" + +from collections.abc import Sequence +import json +import re +import textwrap +from typing import Any + +from smart_buildings.smart_control.environment import conftest as env_conftest +from smart_buildings.smart_control.environment import environment +from smart_buildings.smart_control.environment import hybrid_action_environment +from smart_buildings.smart_control.llm.schema import action_context +from smart_buildings.smart_control.llm.schema import output_schema + +DISCRETE_ACTION_COMMAND = hybrid_action_environment.DISCRETE_ACTION_COMMAND + +DEFAULT_VALIDITY_INTERVALS = output_schema.DEFAULT_VALIDITY_INTERVALS + +EXAMPLE_TIMESTAMP = "2025-01-01 12:00:00" +EXAMPLE_JUSTIFICATION = "These are my overall goals." +EXAMPLE_DEVICE_JUSTIFICATION = "The reason for choosing this setpoint value." + + +def parse_instructions_schema(instructions: str) -> dict[str, Any] | None: + """Parses a string containing a Pydantic schema, returns the schema data.""" + instructions = textwrap.dedent(instructions).strip() + match = re.search(r"```\n({.*})\n```", instructions, re.DOTALL) + if match: + json_string = match.group(1) + try: + schema = json.loads(json_string) + return schema + except json.JSONDecodeError: + return None + return None + + +# +# DEVICE SETPOINTS +# + + +def create_supply_air_heating_temperature_setpoint( + device_id: str = "air_handler_0", + setpoint_value: float = 285.0, + justification: str = "To cool the air.", +) -> output_schema.DeviceSetpoint: + """Creates a supply air heating temperature setpoint for a specific device.""" + return output_schema.DeviceSetpoint( + device_id=device_id, + setpoint_name="supply_air_heating_temperature_setpoint", + setpoint_value=setpoint_value, + justification=justification, + ) + + +def create_supply_water_setpoint( + device_id: str = "boiler_0", + setpoint_value: float = 310.0, + justification: str = "To heat the water.", +) -> output_schema.DeviceSetpoint: + """Creates a supply water temperature setpoint for a specific device.""" + return output_schema.DeviceSetpoint( + device_id=device_id, + setpoint_name="supply_water_setpoint", + setpoint_value=setpoint_value, + justification=justification, + ) + + +def create_supervisor_run_command_setpoint( + device_id: str = "air_handler_0", + setpoint_value: float = 1, + justification: str = "To turn the device on.", +) -> output_schema.DeviceSetpoint: + """Creates a supervisor run command setpoint for a specific device.""" + return output_schema.DeviceSetpoint( + device_id=device_id, + setpoint_name=DISCRETE_ACTION_COMMAND, + setpoint_value=setpoint_value, + justification=justification, + ) + + +# +# ACTIONS (CONTINUOUS) +# + + +def create_action_response( + ahu_1_supply_air_temp: float = 285.0, # -1.0 (bottom of range) + ahu_2_supply_air_temp: float = 295.0, # 1.0 (top of range) + hws_supply_water_temp: float = 330.0, # 0.0 (middle of range) + empty_setpoints: bool = False, + missing_setpoint: bool = False, + missing_field: bool = False, + validity_interval: int = 60, +) -> str: + """Creates an action response for the continuous action environment. + + Provides convenience arguments for creating invalid responses. Only one of + these arguments (empty_setpoints, missing_setpoint, missing_field) should be + set to True at a time. + + Args: + ahu_1_supply_air_temp: The setpoint temp in Kelvin for AHU-1. + ahu_2_supply_air_temp: The setpoint temp in Kelvin for AHU-2. + hws_supply_water_temp: The setpoint temp in Kelvin for HWS. + empty_setpoints: Whether to remove all setpoints from the response, to make + it invalid. + missing_setpoint: Whether to remove a setpoint from the response, to make it + invalid (from the environment's perspective only). + missing_field: Whether to remove a field from a setpoint, to make it + invalid. + validity_interval: The selected validity interval, in minutes. + + Returns: + The action response as a JSON-formatted string. + """ + + action_data = { + "timestamp": EXAMPLE_TIMESTAMP, + "justification": EXAMPLE_JUSTIFICATION, + "validity_interval": validity_interval, + "setpoints": [ + { + "device_id": "air_handler_1", + "setpoint_name": "supply_air_heating_temperature_setpoint", + "setpoint_value": ahu_1_supply_air_temp, + "justification": EXAMPLE_DEVICE_JUSTIFICATION, + }, + { + "device_id": "boiler_1", + "setpoint_name": "supply_water_setpoint", + "setpoint_value": hws_supply_water_temp, + "justification": EXAMPLE_DEVICE_JUSTIFICATION, + }, + { + "device_id": "air_handler_2", + "setpoint_name": "supply_air_heating_temperature_setpoint", + "setpoint_value": ahu_2_supply_air_temp, + "justification": EXAMPLE_DEVICE_JUSTIFICATION, + }, + ], + } + + if sum([empty_setpoints, missing_setpoint, missing_field]) > 1: + raise ValueError( + "Only one of empty_setpoints, missing_setpoint, or missing_field can be" + " set to True at a time." + ) + + if missing_field: + del action_data["setpoints"][0]["justification"] + + if missing_setpoint: + del action_data["setpoints"][0] + + if empty_setpoints: + action_data["setpoints"] = [] + + # Convert data to a JSON-formatted string (to resemble the LLM's response): + return textwrap.dedent(json.dumps(action_data, indent=2)) + + +def create_action() -> output_schema.SetpointsAction: + return output_schema.SetpointsAction( + timestamp=EXAMPLE_TIMESTAMP, + justification=EXAMPLE_JUSTIFICATION, + validity_interval=60, + setpoints=[ + create_supply_air_heating_temperature_setpoint("air_handler_1"), + create_supply_air_heating_temperature_setpoint("air_handler_2"), + create_supply_water_setpoint("boiler_1"), + ], + ) + + +def create_action_context( + env: environment.Environment | None = None, + action: output_schema.SetpointsAction | None = None, +) -> action_context.ActionContext: + """Creates an action context for the continuous action environment.""" + env = env or env_conftest.create_environment(layout=env_conftest.DEMO_LAYOUT) + action = action or create_action() + return action_context.ActionContext(env=env, **action.model_dump()) + + +def create_action_with_custom_intervals( + validity_intervals: Sequence[int] = DEFAULT_VALIDITY_INTERVALS, + selected_interval: int = 60, +) -> output_schema.SetpointsAction: + """Creates a SetpointsAction with custom validity intervals. + + Args: + validity_intervals: The list of possible validity intervals in minutes. + selected_interval: The selected validity interval in minutes. + + Returns: + A SetpointsAction object with custom validity intervals. + """ + model_class = output_schema.create_action_model( + custom_intervals=validity_intervals + ) + + return model_class( + timestamp=EXAMPLE_TIMESTAMP, + justification=EXAMPLE_JUSTIFICATION, + validity_interval=selected_interval, + setpoints=[ + create_supply_air_heating_temperature_setpoint("air_handler_1"), + create_supply_air_heating_temperature_setpoint("air_handler_2"), + create_supply_water_setpoint("boiler_1"), + ], + ) + + +# +# ACTIONS (HYBRID) +# + + +def create_hybrid_action_response( + ahu_1_supply_air_temp: float = 285.0, # -1.0 (bottom of range) + ahu_2_supply_air_temp: float = 295.0, # 1.0 (top of range) + hws_supply_water_temp: float = 330.0, # 0.0 (middle of range) + ahu_1_run_command: int = 1, # ON + ahu_2_run_command: int = 1, # ON + hws_run_command: int = 1, # ON + empty_setpoints: bool = False, + missing_setpoint: bool = False, + missing_field: bool = False, + validity_interval: int = 60, +) -> str: + """Creates an action response for the hybrid action environment. + + Provides convenience arguments for creating invalid responses. Only one of + these arguments (empty_setpoints, missing_setpoint, missing_field) should be + set to True at a time. + + Args: + ahu_1_supply_air_temp: The setpoint temp in Kelvin for AHU-1. + ahu_2_supply_air_temp: The setpoint temp in Kelvin for AHU-2. + hws_supply_water_temp: The setpoint temp in Kelvin for HWS. + ahu_1_run_command: The run command for AHU-1. + ahu_2_run_command: The run command for AHU-2. + hws_run_command: The run command for HWS. + empty_setpoints: Whether to remove all setpoints from the response, to make + it invalid. + missing_setpoint: Whether to remove a setpoint from the response, to make it + invalid (from the environment's perspective only). + missing_field: Whether to remove a field from a setpoint, to make it + invalid. + validity_interval: The selected validity interval, in minutes. + + Returns: + The action response as a JSON-formatted string. + """ + + action_data = { + "timestamp": EXAMPLE_TIMESTAMP, + "validity_interval": validity_interval, + "justification": EXAMPLE_JUSTIFICATION, + "setpoints": [ + { + "device_id": "air_handler_1", + "setpoint_name": DISCRETE_ACTION_COMMAND, + "setpoint_value": ahu_1_run_command, + "justification": EXAMPLE_DEVICE_JUSTIFICATION, + }, + { + "device_id": "air_handler_2", + "setpoint_name": DISCRETE_ACTION_COMMAND, + "setpoint_value": ahu_2_run_command, + "justification": EXAMPLE_DEVICE_JUSTIFICATION, + }, + { + "device_id": "boiler_1", + "setpoint_name": DISCRETE_ACTION_COMMAND, + "setpoint_value": hws_run_command, + "justification": EXAMPLE_DEVICE_JUSTIFICATION, + }, + { + "device_id": "air_handler_1", + "setpoint_name": "supply_air_heating_temperature_setpoint", + "setpoint_value": ahu_1_supply_air_temp, + "justification": EXAMPLE_DEVICE_JUSTIFICATION, + }, + { + "device_id": "air_handler_2", + "setpoint_name": "supply_air_heating_temperature_setpoint", + "setpoint_value": ahu_2_supply_air_temp, + "justification": EXAMPLE_DEVICE_JUSTIFICATION, + }, + { + "device_id": "boiler_1", + "setpoint_name": "supply_water_setpoint", + "setpoint_value": hws_supply_water_temp, + "justification": EXAMPLE_DEVICE_JUSTIFICATION, + }, + ], + } + + if sum([empty_setpoints, missing_setpoint, missing_field]) > 1: + raise ValueError( + "Only one of empty_setpoints, missing_setpoint, or missing_field can be" + " set to True at a time." + ) + + if missing_setpoint: + del action_data["setpoints"][0] + + if missing_field: + del action_data["setpoints"][0]["justification"] + + if empty_setpoints: + action_data["setpoints"] = [] + + # Convert data to a JSON-formatted string (to resemble the LLM's response): + return textwrap.dedent(json.dumps(action_data, indent=2)) + + +def create_hybrid_action() -> output_schema.SetpointsAction: + return output_schema.SetpointsAction( + timestamp=EXAMPLE_TIMESTAMP, + justification=EXAMPLE_JUSTIFICATION, + validity_interval=60, + setpoints=[ + create_supervisor_run_command_setpoint("air_handler_1"), + create_supervisor_run_command_setpoint("air_handler_2"), + create_supervisor_run_command_setpoint("boiler_1"), + create_supply_air_heating_temperature_setpoint("air_handler_1"), + create_supply_air_heating_temperature_setpoint("air_handler_2"), + create_supply_water_setpoint("boiler_1"), + ], + ) + + +def create_hybrid_action_context( + env: hybrid_action_environment.HybridActionEnvironment | None = None, + action: output_schema.SetpointsAction | None = None, +) -> action_context.HybridActionContext: + """Creates an action context for the hybrid action environment.""" + env = env or env_conftest.create_hybrid_action_environment( + layout=env_conftest.DEMO_LAYOUT + ) + action = action or create_hybrid_action() + return action_context.HybridActionContext(env=env, **action.model_dump()) + + +def create_hybrid_action_with_custom_intervals( + validity_intervals: Sequence[int] = DEFAULT_VALIDITY_INTERVALS, + selected_interval: int = 60, +) -> output_schema.SetpointsAction: + """Creates a SetpointsAction with hybrid action and custom validity intervals. + + Args: + validity_intervals: The list of possible validity intervals in minutes. + selected_interval: The selected validity interval in minutes. + + Returns: + A SetpointsAction object with custom validity intervals. + """ + model_class = output_schema.create_action_model( + custom_intervals=validity_intervals + ) + + return model_class( + timestamp=EXAMPLE_TIMESTAMP, + justification=EXAMPLE_JUSTIFICATION, + validity_interval=selected_interval, + setpoints=[ + create_supervisor_run_command_setpoint("air_handler_1"), + create_supervisor_run_command_setpoint("air_handler_2"), + create_supervisor_run_command_setpoint("boiler_1"), + create_supply_air_heating_temperature_setpoint("air_handler_1"), + create_supply_air_heating_temperature_setpoint("air_handler_2"), + create_supply_water_setpoint("boiler_1"), + ], + ) diff --git a/smart_control/llm/schema/formatting_instructions_test.py b/smart_control/llm/schema/formatting_instructions_test.py new file mode 100644 index 00000000..4853ccd0 --- /dev/null +++ b/smart_control/llm/schema/formatting_instructions_test.py @@ -0,0 +1,201 @@ +"""Tests for formatting instructions produced by output schema models. + +When prompting the LLM, we use the Langchain output parser to automatically +generate formatting instructions to be included in the prompt. These +instructions are derived from the output schema model, and include a +description of the desired output, as well as the schema itself. +""" + +import textwrap + +from absl.testing import absltest +import langchain.output_parsers +from smart_buildings.smart_control.llm.schema import conftest +from smart_buildings.smart_control.llm.schema import output_schema + +PydanticOutputParser = langchain.output_parsers.PydanticOutputParser + + +class SchemaParserTest(absltest.TestCase): + """Tests for the schema parser helper function.""" + + def test_parse_instructions_schema(self): + instructions = textwrap.dedent(""" + ``` + { + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + } + ``` + """) + schema = conftest.parse_instructions_schema(instructions) + self.assertDictEqual( + schema, + { + 'type': 'object', + 'properties': { + 'name': { + 'type': 'string', + }, + }, + }, + ) + + def test_parse_instructions_schema_invalid(self): + instructions = 'oops' + schema = conftest.parse_instructions_schema(instructions) + self.assertIsNone(schema) + + def test_parse_instructions_schema_malformed_json(self): + instructions = textwrap.dedent(""" + ``` + { + "invalid": json + } + ``` + """) + schema = conftest.parse_instructions_schema(instructions) + self.assertIsNone(schema) + + +class BaseFormattingInstructionsTest: + """For testing formatting instructions produced by output schema models.""" + + MODEL_CLASS = None # to be set by subclasses + EXPECTED_INTERVALS = None # to be set by subclasses + + def setUp(self): + super().setUp() + self.model_class = self.MODEL_CLASS + self.parser = PydanticOutputParser(pydantic_object=self.model_class) + self.instructions = self.parser.get_format_instructions() + self.schema = conftest.parse_instructions_schema(self.instructions) + + def test_formatting_instructions(self): + self.assertIsInstance(self.instructions, str) + self.assertIn( + 'The output should be formatted as a JSON instance that conforms to the' + ' JSON schema below.', + self.instructions, + ) + + def test_schema(self): + self.assertIsInstance(self.schema, dict) + self.assertCountEqual( + self.schema.keys(), + ['$defs', 'description', 'properties', 'required'], + ) + + +class FormattingInstructionsTest(BaseFormattingInstructionsTest, absltest.TestCase): # pylint: disable=line-too-long + + MODEL_CLASS = output_schema.SetpointsAction + EXPECTED_INTERVALS = list(output_schema.DEFAULT_VALIDITY_INTERVALS) + + def test_schema_required_fields(self): + self.assertCountEqual( + self.schema['required'], + [ + 'timestamp', + 'justification', + 'setpoints', + 'validity_interval', + ], + ) + + def test_schema_properties(self): + self.assertDictEqual( + self.schema['properties'], + { + 'setpoints': { + 'description': 'A list of setpoints.', + 'items': {'$ref': '#/$defs/DeviceSetpoint'}, + 'title': 'Setpoints', + 'type': 'array', + }, + 'timestamp': { + 'description': output_schema.TIMESTAMP_DESCRIPTION, + 'title': 'Timestamp', + 'type': 'string', + }, + 'justification': { + 'description': output_schema.JUSTIFICATION_DESCRIPTION, + 'title': 'Justification', + 'type': 'string', + }, + 'validity_interval': { + 'description': output_schema.VALIDITY_INTERVAL_DESCRIPTION, + 'enum': self.EXPECTED_INTERVALS, + 'title': 'Validity Interval', + 'type': 'integer', + }, + }, + ) + + def test_schema_defs(self): + self.assertListEqual(list(self.schema['$defs'].keys()), ['DeviceSetpoint']) + + schema_def = self.schema['$defs']['DeviceSetpoint'] + expected = { + 'description': ( + 'A single device setpoint.\n\nA device is uniquely identified by' + ' a composite key consisting of the device\nidentifier and the' + ' setpoint name.\n\nAttributes:\n device_id: The unique' + " identifier of the device (e.g. 'boiler-123-xyz').\n " + ' setpoint_name: The name of the setpoint (e.g.' + " 'supply_water_temperature').\n setpoint_value: The requested" + ' value to be set (e.g. 120.0).\n justification: The reason for' + ' choosing this specific device setting.' + ), + 'properties': { + 'device_id': { + 'description': 'The unique identifier of the device.', + 'title': 'Device Id', + 'type': 'string', + }, + 'setpoint_name': { + 'description': 'The name of the setpoint.', + 'title': 'Setpoint Name', + 'type': 'string', + }, + 'setpoint_value': { + 'description': 'The requested value to be set.', + 'title': 'Setpoint Value', + 'type': 'number', + }, + 'justification': { + 'description': ( + 'The reason for choosing this specific device setting.' + ), + 'title': 'Justification', + 'type': 'string', + }, + }, + 'required': [ + 'device_id', + 'setpoint_name', + 'setpoint_value', + 'justification', + ], + 'title': 'DeviceSetpoint', + 'type': 'object', + } + self.assertDictEqual(schema_def, expected) + + +class CustomIntervalInstructionsTest(FormattingInstructionsTest): + + CUSTOM_INTERVALS = [5, 10, 15, 20] + + MODEL_CLASS = output_schema.create_action_model( + custom_intervals=CUSTOM_INTERVALS + ) + EXPECTED_INTERVALS = CUSTOM_INTERVALS + + +if __name__ == '__main__': + absltest.main() diff --git a/smart_control/llm/schema/output_schema.py b/smart_control/llm/schema/output_schema.py new file mode 100644 index 00000000..37eb8141 --- /dev/null +++ b/smart_control/llm/schema/output_schema.py @@ -0,0 +1,156 @@ +"""Model classes for defining the structure of LLM responses. + +Includes models for the individual device setpoints, as well as action +models that contain any number of setpoints, and represent the full response +from an LLM, including other context such as the overall goals of the action. + +For flexibility, we define a base action model which uses a default +set of validity interval options, however we also provide a method for +creating a model that uses a custom set of validity interval options. + +The action model class itself is used to provide LLM formatting instructions, +which are derived from the model class schema. + +When the LLM responds, its response can be used to initialize the action model +class, and can be validated using the model class validator. +""" + +from collections.abc import Sequence +from typing import Literal, TypeAlias + +import pydantic +from smart_buildings.smart_control.utils import serialization + +Field = pydantic.Field + +DEFAULT_VALIDITY_INTERVALS = (5, 10, 15, 20, 30, 45, 60, 75, 90, 120) +DefaultValidityIntervalOptions: TypeAlias = Literal[*DEFAULT_VALIDITY_INTERVALS] # pytype: disable=invalid-annotation # pydantic needs it this way + +TIMESTAMP_DESCRIPTION = ( + "The time the action is taken, formatted as 'YYYY-MM-DD HH:MM:SS', assumed" + " to be in the building's local timezone." +) + +JUSTIFICATION_DESCRIPTION = ( + "The overall reason for taking this action. Includes a brief description" + " of why the action is justified, as well as the desired outcome of the" + " action as a whole." +) + +VALIDITY_INTERVAL_DESCRIPTION = ( + "The number of minutes the setpoints should remain in effect before" + " prompting for a new action." +) + + +class DeviceSetpoint(pydantic.BaseModel): + """A single device setpoint. + + A device is uniquely identified by a composite key consisting of the device + identifier and the setpoint name. + + Attributes: + device_id: The unique identifier of the device (e.g. 'boiler-123-xyz'). + setpoint_name: The name of the setpoint (e.g. 'supply_water_temperature'). + setpoint_value: The requested value to be set (e.g. 120.0). + justification: The reason for choosing this specific device setting. + """ + + device_id: str = Field(description="The unique identifier of the device.") + + setpoint_name: str = Field(description="The name of the setpoint.") + + setpoint_value: float = Field(description="The requested value to be set.") + + justification: str = Field( + description="The reason for choosing this specific device setting." + ) + + @property + def json_metadata(self) -> serialization.SerializableData: + """JSON-serializable metadata.""" + return self.model_dump() + + +class SetpointsAction(pydantic.BaseModel): + """A flexible action model for setting any number of setpoints. + + Attributes: + timestamp: The time the action is taken (in the building's local timezone). + justification: The overall reason for taking this action. Includes a brief + description of why the action is justified, as well as the desired + outcome of the action as a whole. + setpoints: A list of setpoints. + validity_interval: The amount of time in minutes the setpoints should remain + in effect before prompting for a new action. + """ + + timestamp: str = Field(description=TIMESTAMP_DESCRIPTION) + + justification: str = Field(description=JUSTIFICATION_DESCRIPTION) + + setpoints: list[DeviceSetpoint] = Field(description="A list of setpoints.") + + validity_interval: DefaultValidityIntervalOptions = Field( + description=VALIDITY_INTERVAL_DESCRIPTION + ) + + @pydantic.field_validator("setpoints") + @classmethod + def validate_setpoints( + cls, v: Sequence[DeviceSetpoint] + ) -> Sequence[DeviceSetpoint]: + """Ensures the setpoints are present.""" + if not v: + raise ValueError("The setpoints list cannot be empty.") + return v + + def find_setpoint( + self, device_id: str, setpoint_name: str + ) -> DeviceSetpoint | None: + """Returns the setpoint matching the given device id and setpoint name.""" + for setpoint in self.setpoints: + if ( + setpoint.device_id == device_id + and setpoint.setpoint_name == setpoint_name + ): + return setpoint + return None + + @property + def json_metadata(self) -> serialization.SerializableData: + """Serializable metadata.""" + return self.model_dump() + + +def create_action_model( + custom_intervals: Sequence[int], + model_name: str = "SetpointsActionWithCustomInterval", +) -> type[SetpointsAction]: + """Creates an agent action model class, using custom validity intervals. + + Args: + custom_intervals: A list of intervals in minutes. Represents the range of + possible options the LLM has to choose from. + model_name: The name of the action model class to be created. + + Returns: + A Pydantic model class based on `SetpointsAction`, but defined using the + provided set of custom validity intervals. + """ + custom_intervals = sorted(list(set(custom_intervals))) + ValidityIntervalOptions = Literal[*custom_intervals] # pytype: disable=invalid-annotation # pydantic needs it this way + + fields = { + "validity_interval": ( + ValidityIntervalOptions, + Field(description=VALIDITY_INTERVAL_DESCRIPTION), + ) + } + model = pydantic.create_model( + model_name, + **fields, + __base__=SetpointsAction, + ) + model.__doc__ = SetpointsAction.__doc__ + return model diff --git a/smart_control/llm/schema/output_schema_test.py b/smart_control/llm/schema/output_schema_test.py new file mode 100644 index 00000000..d11efd0f --- /dev/null +++ b/smart_control/llm/schema/output_schema_test.py @@ -0,0 +1,202 @@ +"""Tests for LLM response output schema models. + +These tests ensure the output schema models can be initialized. However, the +promptmaker actually uses them to generate formatting instructions. Tests for +that functionality are defined in the "formatting_instructions_test.py" file. +""" + +from typing import get_args + +from absl.testing import absltest +import pydantic +from smart_buildings.smart_control.llm.schema import conftest +from smart_buildings.smart_control.llm.schema import output_schema + +DeviceSetpoint = output_schema.DeviceSetpoint +SetpointsAction = output_schema.SetpointsAction + +EXAMPLE_TIMESTAMP = conftest.EXAMPLE_TIMESTAMP +EXAMPLE_JUSTIFICATION = conftest.EXAMPLE_JUSTIFICATION + + +# +# ACTIONS (CONTINUOUS) +# + + +class ActionValidationsTest(absltest.TestCase): + """Tests for Pydantic model validations, for continuous actions. + + This ensures the model will raise errors if required fields are missing, or if + the data is otherwise not in the expected format. + """ + + def setUp(self): + super().setUp() + self.creation_function = conftest.create_action_response + + def test_valid_setpoints(self): + response_text = self.creation_function() + action = SetpointsAction.model_validate_json(response_text) + self.assertIsInstance(action, SetpointsAction) + + def test_empty_setpoints_raises(self): + response_text = self.creation_function(empty_setpoints=True) + with self.assertRaisesRegex( + pydantic.ValidationError, "setpoints list cannot be empty" + ): + SetpointsAction.model_validate_json(response_text) + + def test_missing_setpoint_ok_beware(self): + # The schema doesn't know about which of the environment's setpoints are + # required. Those validations should happen at the environment level. + response_text = self.creation_function(missing_setpoint=True) + action = SetpointsAction.model_validate_json(response_text) + self.assertIsInstance(action, SetpointsAction) + + def test_missing_field_raises(self): + response_text = self.creation_function(missing_field=True) + with self.assertRaisesRegex(pydantic.ValidationError, "Field required"): + SetpointsAction.model_validate_json(response_text) + + +class ActionTest(absltest.TestCase): + """Tests for the basic action model that uses default validity intervals.""" + + def setUp(self): + super().setUp() + self.n_setpoints_expected = 3 + self.expected_setpoint_names = [ + "supply_air_heating_temperature_setpoint", + "supply_air_heating_temperature_setpoint", + "supply_water_setpoint", + ] + self.action = conftest.create_action() + + def test_validity_interval_options(self): + self.assertCountEqual( + get_args(self.action.__class__.__annotations__["validity_interval"]), + output_schema.DEFAULT_VALIDITY_INTERVALS, + ) + + def test_initialization(self): + self.assertIsInstance(self.action, SetpointsAction) + + def test_attributes(self): + with self.subTest("timestamp"): + self.assertEqual(self.action.timestamp, EXAMPLE_TIMESTAMP) + + with self.subTest("justification"): + self.assertEqual(self.action.justification, EXAMPLE_JUSTIFICATION) + + with self.subTest("validity_interval"): + self.assertEqual(self.action.validity_interval, 60) + + with self.subTest("setpoints"): + self.assertLen(self.action.setpoints, self.n_setpoints_expected) + + names = [setpoint.setpoint_name for setpoint in self.action.setpoints] + self.assertEqual(names, self.expected_setpoint_names) + + for i, setpoint in enumerate(self.action.setpoints): + with self.subTest(f"setpoint at index {i}"): + self.assertIsInstance(setpoint, DeviceSetpoint) + + # TESTS FOR FIND_SETPOINT METHOD: + + def test_find_setpoint_invalid_device_id(self): + setpoint = self.action.find_setpoint( + device_id="oops", setpoint_name="supply_water_setpoint" + ) + self.assertIsNone(setpoint) + + def test_find_setpoint_invalid_setpoint_name(self): + setpoint = self.action.find_setpoint( + device_id="boiler_0", setpoint_name="oops" + ) + self.assertIsNone(setpoint) + + def test_find_setpoint(self): + setpoint = self.action.find_setpoint( + device_id="boiler_1", setpoint_name="supply_water_setpoint" + ) + self.assertIsInstance(setpoint, DeviceSetpoint) + + with self.subTest("attributes"): + self.assertEqual(setpoint.device_id, "boiler_1") + self.assertEqual(setpoint.setpoint_name, "supply_water_setpoint") + self.assertEqual(setpoint.setpoint_value, 310.0) + + +class ActionWithCustomValidityIntervalsTest(ActionTest): + """Tests for the action model that uses custom validity intervals.""" + + def setUp(self): + super().setUp() + self.custom_intervals = [15, 30, 45, 60] + self.action = conftest.create_action_with_custom_intervals( + validity_intervals=self.custom_intervals, + selected_interval=60, + ) + + def test_validity_interval_options(self): + self.assertCountEqual( + get_args(self.action.__class__.__annotations__["validity_interval"]), + self.custom_intervals, + ) + + +# +# ACTIONS (HYBRID) +# + + +class HybridActionValidationsTest(ActionValidationsTest): + """Tests for Pydantic model validations, for hybrid actions. + + This ensures the model will raise errors if required fields are missing, or if + the data is otherwise not in the expected format. + """ + + def setUp(self): + super().setUp() + self.creation_function = conftest.create_hybrid_action_response + + +class HybridActionTest(ActionTest): + """Tests for the hybrid action model that uses default validity intervals.""" + + def setUp(self): + super().setUp() + self.n_setpoints_expected = 6 + self.expected_setpoint_names = [ + "supervisor_run_command", + "supervisor_run_command", + "supervisor_run_command", + "supply_air_heating_temperature_setpoint", + "supply_air_heating_temperature_setpoint", + "supply_water_setpoint", + ] + self.action = conftest.create_hybrid_action() + + +class HybridActionWithCustomValidityIntervalsTest(HybridActionTest): + """Tests for the hybrid action model that uses custom validity intervals.""" + + def setUp(self): + super().setUp() + self.custom_intervals = [15, 30, 45, 60] + self.action = conftest.create_hybrid_action_with_custom_intervals( + validity_intervals=self.custom_intervals, + selected_interval=60, + ) + + def test_validity_interval_options(self): + self.assertCountEqual( + get_args(self.action.__class__.__annotations__["validity_interval"]), + self.custom_intervals, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/smart_control/llm/services/conftest.py b/smart_control/llm/services/conftest.py new file mode 100644 index 00000000..101c71de --- /dev/null +++ b/smart_control/llm/services/conftest.py @@ -0,0 +1,125 @@ +"""Helpers for testing LLM services. + +The tests will implement mocked responses by default. + +To test the actual responses returned by the Gemini API, optionally set the +`TEST_GEMINI_SERVICE_LIVE` environment variable to 'true'. + +To test the actual responses returned by the Vertex AI API, optionally set the +`TEST_VERTEX_SERVICE_LIVE` environment variable to 'true'. +""" + +import os +from unittest import mock + +import dotenv +from google import genai +from google.auth import credentials +from smart_buildings.smart_control.llm.services import gemini_service +from smart_buildings.smart_control.llm.services import llm_service +from smart_buildings.smart_control.llm.services import vertex_service + +dotenv.load_dotenv() + +TEST_GEMINI_SERVICE_LIVE = bool( + os.getenv("TEST_GEMINI_SERVICE_LIVE", default="false").lower() == "true" +) +TEST_VERTEX_SERVICE_LIVE = bool( + os.getenv("TEST_VERTEX_SERVICE_LIVE", default="false").lower() == "true" +) + +SKIP_REASON = "Skip API Calls in tests by default." + +PROMPT_TEXT = "What year was America founded?" +RESPONSE_TEXT = ( + "The United States was founded in 1776 after the Declaration of " + "Independence." +) + + +class FakeLLMService(llm_service.BaseLLMService): + """Generic Fake LLM Service, used for testing.""" + + def __init__(self, response_text: str = RESPONSE_TEXT): + self._temperature = 0.0 + self._response_text = response_text + + @property + def model_name(self) -> str: + return "fake-model" + + @property + def temperature(self) -> float: + return self._temperature + + def get_response(self, prompt: str) -> str: + return self._response_text + + +def create_fake_llm_service( + response_text: str = RESPONSE_TEXT, +) -> FakeLLMService: + """Creates a fake version of a generic LLM Service. + + It will return the specified response text instead of making an API call. + + Args: + response_text: The text to return from the LLM Service. + + Returns: + A fake version of the LLM Service. + """ + return FakeLLMService(response_text=response_text) + + +def create_mock_gemini_service( + response_text: str = RESPONSE_TEXT, +) -> gemini_service.GeminiService: + """Creates a mock version of the Gemini Service. + + It will return the specified response text instead of making an API call. + + Args: + response_text: The text to return from the Gemini Service. If not provided, + a default response text will be used. + + Returns: + A mock version of the Gemini Service. + """ + # mocked dependencies: + client = mock.create_autospec(genai.Client, instance=True) + generate_content_response = mock.MagicMock() + generate_content_response.text = response_text + client.models.generate_content.return_value = generate_content_response + + # dependency injection: + return gemini_service.GeminiService(api_key="fake_api_key", client=client) + + +def create_mock_vertex_service( + response_text: str = RESPONSE_TEXT, # pylint: disable=unused-argument +) -> vertex_service.VertexAIService: + """Creates a mock version of the Vertex AI Service. + + It will return the specified response text instead of making an API call. + + Args: + response_text: The text to return from the Vertex AI Service. If not + provided, a default response text will be used. + + Returns: + A mock version of the Vertex AI Service. + """ + # mocked credentials: + creds = mock.create_autospec(credentials.Credentials, instance=True) + + # mocked client: + client = mock.create_autospec(genai.Client, instance=True) + generate_content_response = mock.MagicMock() + generate_content_response.text = response_text + client.models.generate_content.return_value = generate_content_response + + # dependency injection: + return vertex_service.VertexAIService( + project_id="not-a-real-project", credentials=creds, client=client + ) diff --git a/smart_control/llm/services/gemini_service.py b/smart_control/llm/services/gemini_service.py new file mode 100644 index 00000000..429d7952 --- /dev/null +++ b/smart_control/llm/services/gemini_service.py @@ -0,0 +1,218 @@ +# pylint: disable=line-too-long +r"""A Gemini service that uses the Gemini API directly, using an API key. + +Run with blaze: + +```shell +$ blaze run //third_party/py/smart_buildings/smart_control/llm/services:gemini_service_script +``` + +Run with python: + +```shell +$ python -m smart_buildings.smart_control.llm.services.gemini_service +``` + +Optional flags: + --gemini_api_key: API key to use for the Gemini API. + --gemini_model_temperature: The model temperature. + +Example: + +```shell +$ blaze run //third_party/py/smart_buildings/smart_control/llm/services:gemini_service_script -- \ + --gemini_api_key= --gemini_model_temperature=0.5 +``` +""" +# pylint: enable=line-too-long + +import abc +import getpass +import os +from typing import Any, Sequence + +from absl import app +from absl import flags +import dotenv +from google import genai +from smart_buildings.smart_control.llm.services import llm_service + +dotenv.load_dotenv() + +GEMINI_API_KEY = os.getenv('GEMINI_API_KEY') +MODEL_NAME = os.getenv('GEMINI_MODEL_NAME', default='gemini-2.0-flash') + +TEMPERATURE = 0.1 +TOP_P = 0.95 +TOP_K = 40 +MAX_OUTPUT_TOKENS = 1024 + + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + name='gemini_api_key', + default=None, + help='API key to use for the Gemini API.', +) + +flags.DEFINE_string( + name='gemini_model_temperature', default=None, help='The model temperature.' +) + + +class BaseGeminiService(llm_service.BaseLLMService, metaclass=abc.ABCMeta): + """A Gemini service interface allowing for flexible credentials approaches. + + Attributes: + model_name: The name of the Gemini model to use. + temperature: Controls the randomness of the output. Higher values mean more + random, lower values mean more deterministic. + top_p: Nucleus sampling parameter. Considers the smallest set of tokens + whose cumulative probability exceeds this value. + top_k: Top-k sampling parameter. Considers the top k most likely tokens at + each step. + max_output_tokens: The maximum number of tokens to generate. + generation_config: The generation config to use for the model. + api_key: The API key to use for the Gemini API. + client: The model client. + """ + + def __init__( + self, + model_name: str = MODEL_NAME, + temperature: float = TEMPERATURE, + top_p: float = TOP_P, + top_k: float = TOP_K, + max_output_tokens: int = MAX_OUTPUT_TOKENS, + ): + """Initializes a Gemini service interface. + + Args: + model_name: The name of the Gemini model to use. + temperature: Controls the randomness of the output. Higher values mean + more random, lower values mean more deterministic. + top_p: Nucleus sampling parameter. Considers the smallest set of tokens + whose cumulative probability exceeds this value. + top_k: Top-k sampling parameter. Considers the top k most likely tokens at + each step. + max_output_tokens: The maximum number of tokens to generate. + """ + self._model_name = model_name + self._temperature = temperature + self.top_p = top_p + self.top_k = top_k + self._max_output_tokens = max_output_tokens + + @property + def json_metadata(self) -> dict[str, Any]: + """Info to write into a JSON file. Needs to be serializable.""" + return { + 'type': self.__class__.__name__, + 'model_name': self.model_name, + 'generation_config': self.generation_config, + } + + @property + def model_name(self) -> str: + return self._model_name + + @property + def temperature(self) -> float: + return self._temperature + + @property + def max_output_tokens(self) -> int: + return self._max_output_tokens + + @property + @abc.abstractmethod + def client(self) -> genai.Client: + """Returns a client for the Gemini service.""" + + @property + def generation_config(self) -> dict[str, Any]: + return { + 'temperature': self.temperature, + 'top_p': self.top_p, + 'top_k': self.top_k, + 'max_output_tokens': self.max_output_tokens, + } + + def get_response(self, prompt: str) -> str: + """Returns the response from the Gemini model.""" + response = self.client.models.generate_content( + model=self.model_name, contents=prompt, config=self.generation_config + ) + return response.text + + +class GeminiService(BaseGeminiService): + """A Gemini service that uses the Gemini API directly, using an API key. + + Will use the `GEMINI_API_KEY` environment variable if provided. + """ + + def __init__( + self, + api_key: str = GEMINI_API_KEY, + model_name: str = MODEL_NAME, + temperature: float = TEMPERATURE, + top_p: float = TOP_P, + top_k: float = TOP_K, + max_output_tokens: int = MAX_OUTPUT_TOKENS, + client: genai.Client | None = None, + ): + """Initializes the Gemini service. + + Args: + api_key: The API key for the Gemini API. Will use the `GEMINI_API_KEY` + environment variable if provided. + model_name: The name of the Gemini model to use. + temperature: Controls the randomness of the output. Higher values mean + more random, lower values mean more deterministic. + top_p: Nucleus sampling parameter. Considers the smallest set of tokens + whose cumulative probability exceeds this value. + top_k: Top-k sampling parameter. Considers the top k most likely tokens at + each step. + max_output_tokens: The maximum number of tokens to generate. + client: An optional client to use for the Gemini API. Primarily used to + facilitate dependency injection during testing. If not provided, a new + client will be created using the specified api_key. + """ + super().__init__( + model_name=model_name, + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_output_tokens=max_output_tokens, + ) + + if not api_key: + raise ValueError( + 'Please provide an api_key, or set the GEMINI_API_KEY ' + 'environment variable.' + ) + self.api_key = api_key + + self._client = client or genai.Client(api_key=self.api_key) + + @property + def client(self) -> genai.Client: + return self._client + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + api_key = FLAGS.gemini_api_key or getpass.getpass('API Key: ') or GEMINI_API_KEY # pylint: disable=line-too-long + temp = FLAGS.gemini_model_temperature or input('Temperature: ') or TEMPERATURE + service = GeminiService(api_key=api_key, temperature=temp) + + user_prompt = input('Prompt: ') or 'When was America founded?' + print(service.get_response(user_prompt)) + + +if __name__ == '__main__': + app.run(main) diff --git a/smart_control/llm/services/gemini_service_test.py b/smart_control/llm/services/gemini_service_test.py new file mode 100644 index 00000000..385689a2 --- /dev/null +++ b/smart_control/llm/services/gemini_service_test.py @@ -0,0 +1,69 @@ +"""Tests for Gemini LLM service.""" + +import unittest +from unittest import mock + +from absl.testing import absltest +from google import genai +from smart_buildings.smart_control.llm.services import conftest +from smart_buildings.smart_control.llm.services.gemini_service import GeminiService # pylint: disable=g-importing-member + +FAKE_API_KEY = "not-a-real-api-key" + + +class GeminiServiceTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.service = GeminiService(api_key=FAKE_API_KEY) + + def test_api_key(self): + self.assertEqual(self.service.api_key, FAKE_API_KEY) + + def test_client(self): + self.assertIsInstance(self.service.client, genai.Client) + + def test_temperature(self): + self.assertEqual(self.service.temperature, 0.1) + + def test_generation_config(self): + config = self.service.generation_config + expected_config = { + "temperature": 0.1, + "top_p": 0.95, + "top_k": 40, + "max_output_tokens": 1024, + } + self.assertEqual(config, expected_config) + + @unittest.skipUnless(conftest.TEST_GEMINI_SERVICE_LIVE, conftest.SKIP_REASON) + def test_get_response(self): + response = self.service.get_response(conftest.PROMPT_TEXT) + self.assertIsInstance(response, str) + + def test_get_response_mocked(self): + client = mock.create_autospec(genai.Client, instance=True) + generate_content_response = mock.MagicMock() + generate_content_response.text = conftest.RESPONSE_TEXT + client.models.generate_content.return_value = generate_content_response + + service = GeminiService(api_key=FAKE_API_KEY, client=client) + response = service.get_response(conftest.PROMPT_TEXT) + self.assertIsInstance(response, str) + self.assertEqual(response, conftest.RESPONSE_TEXT) + + +class MockedGeminiServiceTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.service = conftest.create_mock_gemini_service() + + def test_get_response(self): + response = self.service.get_response(conftest.PROMPT_TEXT) + self.assertIsInstance(response, str) + self.assertEqual(response, conftest.RESPONSE_TEXT) + + +if __name__ == "__main__": + absltest.main() diff --git a/smart_control/llm/services/llm_service.py b/smart_control/llm/services/llm_service.py new file mode 100644 index 00000000..8038acc3 --- /dev/null +++ b/smart_control/llm/services/llm_service.py @@ -0,0 +1,31 @@ +"""Provides a generic interface for an LLM service.""" + +import abc +from typing import Any + + +class BaseLLMService(metaclass=abc.ABCMeta): + """Base class defining the common interface for an LLM service.""" + + @property + def json_metadata(self) -> dict[str, Any]: + """Info to write into a JSON file. Needs to be serializable.""" + return { + "type": self.__class__.__name__, + "model_name": self.model_name, + "temperature": self.temperature, + } + + @property + @abc.abstractmethod + def model_name(self) -> str: + """Returns the LLM model name.""" + + @property + @abc.abstractmethod + def temperature(self) -> float: + """Returns the LLM temperature.""" + + @abc.abstractmethod + def get_response(self, prompt: str) -> str | None: + """Returns the LLM's textual response from a given prompt.""" diff --git a/smart_control/llm/services/llm_service_test.py b/smart_control/llm/services/llm_service_test.py new file mode 100644 index 00000000..18e8ec3c --- /dev/null +++ b/smart_control/llm/services/llm_service_test.py @@ -0,0 +1,22 @@ +"""Tests for the Base LLM Service interface.""" + +from absl.testing import absltest +from smart_buildings.smart_control.llm.services import conftest + + +class LlmServiceTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.service = conftest.create_fake_llm_service() + + def test_temperature(self): + self.assertEqual(self.service.temperature, 0.0) + + def test_get_response(self): + response = self.service.get_response(conftest.PROMPT_TEXT) + self.assertEqual(response, conftest.RESPONSE_TEXT) + + +if __name__ == "__main__": + absltest.main() diff --git a/smart_control/llm/services/vertex_service.py b/smart_control/llm/services/vertex_service.py new file mode 100644 index 00000000..a6ea9461 --- /dev/null +++ b/smart_control/llm/services/vertex_service.py @@ -0,0 +1,93 @@ +"""A Gemini Service that uses the Vertex AI platform, and a GCP project.""" + +import os +from typing import Any + +import dotenv +from google import auth +from google import genai +from google.genai import types +from smart_buildings.smart_control.llm.services import gemini_service + +dotenv.load_dotenv() + +CREDENTIALS = os.getenv("GOOGLE_APPLICATION_CREDENTIALS") +PROJECT_ID = os.getenv("VERTEX_AI_PROJECT_ID", default="smart-buildings-dev") +LOCATION = os.getenv("VERTEX_AI_LOCATION", default="us-central1") +MODEL_NAME = os.getenv("VERTEX_AI_MODEL_NAME", default="gemini-2.5-flash") + +SAFETY_DISABLED = ( + types.SafetySetting( + category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE" + ), + types.SafetySetting( + category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE" + ), + types.SafetySetting( + category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE" + ), + types.SafetySetting( + category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE" + ), +) + + +class VertexAIService(gemini_service.BaseGeminiService): + """A Gemini Service that uses Vertex AI and a GCP project. + + Attributes: + project_id: The GCP project ID to use for the Vertex AI service. + location: The GCP location to use for the Vertex AI service. + credentials: The credentials to use for the Vertex AI service. + safety_settings: The safety settings to use for the Vertex AI service. + client: The client to use for the Vertex AI service. + """ + + def __init__( + self, + project_id: str | None = PROJECT_ID, + location: str = LOCATION, + model_name: str = MODEL_NAME, + temperature: float = gemini_service.TEMPERATURE, + top_p: float = gemini_service.TOP_P, + top_k: float = gemini_service.TOP_K, + max_output_tokens: int = gemini_service.MAX_OUTPUT_TOKENS, + safety_settings: list[types.SafetySetting] | None = None, + credentials: auth.credentials.Credentials | None = None, + client: genai.Client | None = None, + ): + super().__init__( + model_name=model_name, + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_output_tokens=max_output_tokens, + ) + + self.project_id = project_id + self.location = location + self.credentials = credentials or CREDENTIALS + self.safety_settings = safety_settings or SAFETY_DISABLED + + # use default credentials if not provided: + if not self.credentials and not self.project_id: + self.credentials, self.project_id = auth.default() + + self._client = client or genai.Client( + vertexai=True, + project=self.project_id, + location=self.location, + credentials=self.credentials, + ) + + @property + def client(self) -> genai.Client: + """Returns a client for the Vertex AI service.""" + return self._client + + @property + def generation_config(self) -> dict[str, Any]: + """Returns the generation config for the Vertex AI service.""" + config = super().generation_config.copy() + config["safety_settings"] = self.safety_settings + return config diff --git a/smart_control/llm/services/vertex_service_test.py b/smart_control/llm/services/vertex_service_test.py new file mode 100644 index 00000000..4438c4f5 --- /dev/null +++ b/smart_control/llm/services/vertex_service_test.py @@ -0,0 +1,92 @@ +"""Tests for Vertex AI LLM service.""" + +import unittest +from unittest import mock + +from absl.testing import absltest +from google import genai +from google.auth import credentials +from smart_buildings.smart_control.llm.services import conftest +from smart_buildings.smart_control.llm.services import vertex_service + + +class VertexAIServiceTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.service = vertex_service.VertexAIService() + + def test_project_id(self): + self.assertEqual(self.service.project_id, 'smart-buildings-dev') + + def test_location(self): + self.assertEqual(self.service.location, 'us-central1') + + def test_model_name(self): + self.assertEqual(self.service.model_name, 'gemini-2.5-flash') + + def test_temperature(self): + self.assertEqual(self.service.temperature, 0.1) + + @unittest.skipUnless(conftest.TEST_VERTEX_SERVICE_LIVE, conftest.SKIP_REASON) + def test_credentials(self): + self.assertIsInstance(self.service.credentials, credentials.Credentials) + + def test_client(self): + self.assertIsInstance(self.service.client, genai.Client) + + def test_generation_config(self): + config = self.service.generation_config + self.assertIsInstance(config, dict) # or genai.types.GenerationConfig + expected_config = { + 'temperature': 0.1, + 'top_p': 0.95, + 'top_k': 40, + 'max_output_tokens': 1024, + 'safety_settings': vertex_service.SAFETY_DISABLED, + } + self.assertEqual(config, expected_config) + + @unittest.skipUnless(conftest.TEST_VERTEX_SERVICE_LIVE, conftest.SKIP_REASON) + def test_get_response(self): + response = self.service.get_response(conftest.PROMPT_TEXT) + # non-deterministic result from real service, just checking the type: + self.assertIsInstance(response, str) + + def test_get_response_mocked(self): + # mocked credentials: + creds = mock.create_autospec(credentials.Credentials, instance=True) + + # mocked client: + client = mock.create_autospec(genai.Client, instance=True) + generate_content_response = mock.MagicMock() + generate_content_response.text = conftest.RESPONSE_TEXT + client.models.generate_content.return_value = generate_content_response + + # dependency injection: + service = vertex_service.VertexAIService( + project_id='not-a-real-project', credentials=creds, client=client + ) + + # test the response: + response = service.get_response(conftest.PROMPT_TEXT) + self.assertIsInstance(response, str) + self.assertEqual(response, conftest.RESPONSE_TEXT) + + +class MockedVertexAIServiceTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.service = conftest.create_mock_vertex_service() + + def test_credentials(self): + self.assertIsInstance(self.service.credentials, credentials.Credentials) + + def test_get_response(self): + response = self.service.get_response(conftest.PROMPT_TEXT) + self.assertEqual(response, conftest.RESPONSE_TEXT) + + +if __name__ == '__main__': + absltest.main() diff --git a/smart_control/llm/utils/schedule_models.py b/smart_control/llm/utils/schedule_models.py new file mode 100644 index 00000000..f0b9eac5 --- /dev/null +++ b/smart_control/llm/utils/schedule_models.py @@ -0,0 +1,308 @@ +"""Schedule models. + +These models represent the daily and weekly operational schedules for a given +building, and are used by the schedule tool to determine the operational mode +of the building, based on the current day and time. + +These daily and weekly schedules are meant to provide general templates that are +applied to all weeks, and are not tied to specific dates. As such, they also do +not account for holidays, which should be incorporated separately. + +The models use timezone-aware on and off times, to ensure accurate comparisons. +""" + +import calendar +from collections.abc import Mapping +from collections.abc import Sequence +import dataclasses +import datetime +from typing import Any, Self +import zoneinfo + +# FYI: The calendar module may use different day names depending on the locale. +# We assume the calendar is set to the English locale. +DAY_NAMES = tuple(calendar.day_name) # ("Monday", "Tuesday", etc.) + + +def str_to_time_with_zone(time_str: str, time_zone: str) -> datetime.time: + """Returns a datetime.time object that is timezone aware.""" + tzinfo = zoneinfo.ZoneInfo(time_zone) + return datetime.time.fromisoformat(time_str).replace(tzinfo=tzinfo) + + +def display_time(time: datetime.time | None) -> str | None: + """Displays a given time as a string. + + This method is used for display and JSON serialization purposes only, not for + comparisons. + + It needs to handle None values because the daily schedule's on and off times + can be None (which designates a non-operational day). + + Args: + time: The time to convert, or None. + + Returns: + The time as a string, like "07:00", or None. + """ + return time.strftime("%H:%M") if time is not None else None + + +@dataclasses.dataclass(frozen=True) +class DailySchedule: + """The planned operational schedule for a given day of week. + + This model assumes that there is a single operational period for the day, and + that the building and its devices should be "ON" during the hours between the + `on_time` and `off_time`, and "OFF" otherwise. + + A given day should have both an `on_time` and `off_time` (to designate an + operational day), or neither (using None values to designate a non-operational + day). If present, both times must be timezone aware, and share the same time + zone. When comparing a time to the on_time and off_time, the comparison time + must also be timezone aware, and have the same time zone (see the + `is_during_operational_hours` method for more details). + + Attributes: + day_name: The name of the day of the week (e.g. "Monday"). + on_time: The time of day when devices should be turned on, or None. + off_time: The time of day when devices should be turned off, or None. + time_zone: The time zone used for the on_time and off_time. Required, even + if the on_time and off_time are None. + """ + + day_name: str + time_zone: str = "UTC" + on_time: datetime.time | None = None + off_time: datetime.time | None = None + + # VALIDATIONS + + def __post_init__(self) -> None: + self._validate_day_name() + self._validate_time_zone() + self._validate_times() + self._validate_times_zones() + self._validate_times_start_after_end() + + def _validate_day_name(self) -> None: + """Ensures the day name is valid.""" + if self.day_name not in DAY_NAMES: + raise ValueError( + f"Unknown day name: {self.day_name}. Expecting one of: {DAY_NAMES}." + ) + + def _validate_time_zone(self) -> None: + """Ensures the time zone is present and valid.""" + if self.time_zone is None: + raise ValueError("The time zone must be specified.") + + try: + self.tzinfo # pylint: disable=pointless-statement + except zoneinfo.ZoneInfoNotFoundError as err: + raise ValueError(f"Invalid time zone: {self.time_zone}.") from err + + def _validate_times(self) -> None: + """Ensures both times are present, or neither are.""" + if (self.on_time is None and self.off_time is not None) or ( + self.on_time is not None and self.off_time is None + ): + raise ValueError( + "The on_time and off_time must both be specified, or both be None." + ) + + def _validate_times_zones(self) -> None: + """Ensures both times have a time zone, and they match the schedule.""" + if self.on_time is None or self.off_time is None: + return + + if self.on_time.tzinfo is None: + raise ValueError("The on_time needs to have a time zone.") + + if self.off_time.tzinfo is None: + raise ValueError("The off_time needs to have a time zone.") + + if ( + self.on_time.tzinfo != self.tzinfo + or self.off_time.tzinfo != self.tzinfo + ): + raise ValueError( + "The on_time and off_time must have the same time zone, and it must " + f"match the schedule's time zone: {self.time_zone}." + ) + + def _validate_times_start_after_end(self) -> None: + """Ensures the on_time is before the off_time.""" + if self.on_time is not None and self.off_time is not None: + if self.on_time >= self.off_time: + raise ValueError("The on_time must be before the off_time.") + + # CONSTRUCTOR + + @classmethod + def from_times( + cls, + *, + day_name: str, + on_time: str | None, + off_time: str | None, + time_zone: str | None = "UTC", + ) -> Self: + """Creates a DailySchedule from 24-hr time strings. + + This method allows you to pass timezone-naive strings for convenience. It + will apply the specified time zone to each of the times to ensure they are + both timezone aware. + + Args: + day_name: The name of the day of the week (e.g. "Monday"). + on_time: The time of day when devices should be turned on, as a string + like "07:00", or None if the day is not operational. + off_time: The time of day when devices should be turned off, as a string + like "19:00", or None if the day is not operational. + time_zone: The time zone to use for the on_time and off_time. Defaults to + "UTC". + + Returns: + A DailySchedule instance. + """ + if on_time is not None: + on_time = str_to_time_with_zone(on_time, time_zone=time_zone) + + if off_time is not None: + off_time = str_to_time_with_zone(off_time, time_zone=time_zone) + + return cls( + day_name=day_name, + on_time=on_time, + off_time=off_time, + time_zone=time_zone, + ) + + # METHODS AND PROPERTIES + + @property + def tzinfo(self) -> zoneinfo.ZoneInfo: + """Information about the given time zone, as a zoneinfo.ZoneInfo object.""" + return zoneinfo.ZoneInfo(self.time_zone) + + @property + def is_operational_day(self) -> bool: + """Whether this day is scheduled to be an operational day.""" + return self.on_time is not None and self.off_time is not None + + def is_during_operational_hours(self, time: datetime.time) -> bool: + """Determines if the given time is within the scheduled hours. + + The comparison time needs to be timezone-aware, and have the same time zone + as the on_time and off_time, which have both already been validated to have + the same time zone. + + Note about edge cases: The start time is considered operational (inclusive), + but the end time is considered non-operational (exclusive). + + Args: + time: The time to check. Must be timezone aware, and have the same time + zone as the schedule. + + Returns: + A boolean indicating whether the given time falls within the scheduled + hours. + """ + if time.tzinfo is None: + raise ValueError("The comparison time must have a time zone.") + + if str(time.tzinfo) != str(self.tzinfo): + raise ValueError( + "The comparison time must have the same time zone as the schedule." + ) + + if not self.is_operational_day: + return False + + return self.on_time <= time < self.off_time + + +@dataclasses.dataclass(frozen=True) +class WeeklySchedule: + """The operational schedule for a given week. + + The weekly schedule contains a daily schedule for each day of the week. + + Attributes: + daily_schedules: A sequence of DailySchedules for each day of the week. + """ + + daily_schedules: Sequence[DailySchedule] + + # VALIDATIONS + + def __post_init__(self) -> None: + self._validate_all_days() + + def _validate_all_days(self) -> None: + """Ensures all expected day names are present.""" + day_names = [schedule.day_name for schedule in self.daily_schedules] + if sorted(day_names) != sorted(DAY_NAMES): + raise ValueError( + "Weekly schedule must have a schedule for each day of the week." + f" Expected: {DAY_NAMES}, got: {day_names}." + ) + + # CONSTRUCTOR + + @classmethod + def from_dict( + cls, + schedule_dict: Mapping[str, Sequence[str | None]], + time_zone: str | None = "UTC", + ) -> Self: + """Creates a WeeklySchedule from a dictionary of DailySchedules.""" + return cls([ + DailySchedule.from_times( + day_name=day_name, + on_time=on_time, + off_time=off_time, + time_zone=time_zone, + ) + for day_name, (on_time, off_time) in schedule_dict.items() + ]) + + # PROPERTIES AND METHODS + + @property + def time_zone(self) -> str: + """The time zone used for all the daily schedules.""" + return self.daily_schedules[0].time_zone + + def get_daily_schedule(self, day_name: str) -> DailySchedule: + """Returns the daily schedule for the given day of week. + + Args: + day_name: The name of the day of the week (e.g. "Monday"). + + Raises: + ValueError: If the day name is not in the weekly schedule. + + Returns: + The DailySchedule instance for the given day of week. + """ + for schedule in self.daily_schedules: + if schedule.day_name == day_name: + return schedule + raise ValueError(f"Unknown day name: {day_name}") + + @property + def json_metadata(self) -> dict[str, Any]: + """Info about the weekly schedule, in a JSON serializable format.""" + daily_schedules_dict = { + schedule.day_name: { + "on_time": display_time(schedule.on_time), + "off_time": display_time(schedule.off_time), + } + for schedule in self.daily_schedules + } + return { + "time_zone": self.time_zone, + "daily_schedules": daily_schedules_dict, + } diff --git a/smart_control/llm/utils/schedule_models_test.py b/smart_control/llm/utils/schedule_models_test.py new file mode 100644 index 00000000..28026b72 --- /dev/null +++ b/smart_control/llm/utils/schedule_models_test.py @@ -0,0 +1,346 @@ +import datetime +import zoneinfo + +from absl.testing import absltest +from absl.testing import parameterized +from smart_buildings.smart_control.llm.utils import schedule_models + +UTC = "UTC" +EST = "America/New_York" +PST = "America/Los_Angeles" + +UTC_INFO = zoneinfo.ZoneInfo(UTC) +EST_INFO = zoneinfo.ZoneInfo(EST) +PST_INFO = zoneinfo.ZoneInfo(PST) + +TIME = datetime.time(8, 0) # timezone naive +TIME_UTC = datetime.time(8, 0, tzinfo=UTC_INFO) +TIME_PST = datetime.time(8, 0, tzinfo=PST_INFO) +TIME_EST = datetime.time(8, 0, tzinfo=EST_INFO) + +OFF_TIME = datetime.time(18, 0) # timezone naive +OFF_TIME_UTC = datetime.time(18, 0, tzinfo=UTC_INFO) +OFF_TIME_PST = datetime.time(18, 0, tzinfo=PST_INFO) +OFF_TIME_EST = datetime.time(18, 0, tzinfo=EST_INFO) + + +class TimeConversionsTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name="naive_time_str_to_utc", + time_str="08:00", + time_zone=UTC, + expected=TIME_UTC, + ), + dict( + testcase_name="naive_time_str_to_eastern", + time_str="08:00", + time_zone=EST, + expected=TIME_EST, + ), + dict( + testcase_name="naive_time_str_to_pacific", + time_str="08:00", + time_zone=PST, + expected=TIME_PST, + ), + dict( + testcase_name="tz_eastern", + time_str="08:00", + time_zone=EST, + expected=TIME_EST, + ), + ) + def test_str_to_time_with_zone(self, time_str, time_zone, expected): + self.assertEqual( + schedule_models.str_to_time_with_zone(time_str, time_zone), expected + ) + + +# +# DAILY SCHEDULE TESTS +# + + +class OperationalDailyScheduleTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.schedule = schedule_models.DailySchedule.from_times( + day_name="Monday", on_time="08:00", off_time="18:00" + ) + + def test_init(self): + self.assertIsInstance(self.schedule, schedule_models.DailySchedule) + + def test_attributes(self): + self.assertEqual(self.schedule.day_name, "Monday") + self.assertEqual(self.schedule.on_time, TIME_UTC) + self.assertEqual(self.schedule.off_time, OFF_TIME_UTC) + self.assertEqual(self.schedule.time_zone, UTC) + + def test_is_operational_day(self): + self.assertTrue(self.schedule.is_operational_day) + + @parameterized.named_parameters( + dict(testcase_name="during_hours", hour=12, minute=0, expected=True), + dict(testcase_name="before_hours", hour=7, minute=0, expected=False), + dict(testcase_name="after_hours", hour=19, minute=0, expected=False), + dict(testcase_name="start_of_hours", hour=8, minute=0, expected=True), + dict(testcase_name="end_of_hours", hour=18, minute=0, expected=False), + ) + def test_is_during_operational_hours(self, hour, minute, expected): + self.assertEqual( + self.schedule.is_during_operational_hours( + datetime.time(hour, minute, tzinfo=UTC_INFO) + ), + expected, + ) + + def test_is_during_operational_hours_with_wrong_time_zone_raises(self): + with self.assertRaisesRegex( + ValueError, + "The comparison time must have the same time zone as the schedule.", + ): + self.schedule.is_during_operational_hours( + datetime.time(12, 0, tzinfo=PST_INFO) + ) + + def test_is_during_operational_hours_with_naive_time_raises(self): + with self.assertRaisesRegex( + ValueError, + "The comparison time must have a time zone.", + ): + self.schedule.is_during_operational_hours(datetime.time(12, 0)) + + +class NonOperationalDailyScheduleTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.schedule = schedule_models.DailySchedule.from_times( + day_name="Monday", on_time=None, off_time=None + ) + + def test_init(self): + self.assertIsInstance(self.schedule, schedule_models.DailySchedule) + + def test_attributes(self): + self.assertEqual(self.schedule.day_name, "Monday") + self.assertIsNone(self.schedule.on_time) + self.assertIsNone(self.schedule.off_time) + + def test_is_operational_day(self): + self.assertFalse(self.schedule.is_operational_day) + + @parameterized.parameters( + datetime.time(12, 0, tzinfo=UTC_INFO), + datetime.time(7, 0, tzinfo=UTC_INFO), + datetime.time(19, 0, tzinfo=UTC_INFO), + ) + def test_is_during_operational_hours(self, time): + self.assertFalse(self.schedule.is_during_operational_hours(time)) + + +class DailyScheduleValidationsTest(absltest.TestCase): + + def test_invalid_day_name_raises(self): + with self.assertRaisesRegex(ValueError, "Unknown day name: Funday"): + schedule_models.DailySchedule.from_times( + day_name="Funday", on_time="08:00", off_time="18:00" + ) + + def test_missing_on_time_raises(self): + with self.assertRaisesRegex( + ValueError, + "The on_time and off_time must both be specified, or both be None.", + ): + schedule_models.DailySchedule.from_times( + day_name="Monday", on_time=None, off_time="18:00" + ) + + def test_missing_off_time_raises(self): + with self.assertRaisesRegex( + ValueError, + "The on_time and off_time must both be specified, or both be None.", + ): + schedule_models.DailySchedule.from_times( + day_name="Monday", on_time="08:00", off_time=None + ) + + def test_on_after_off_raises(self): + with self.assertRaisesRegex( + ValueError, "The on_time must be before the off_time." + ): + schedule_models.DailySchedule.from_times( + day_name="Monday", on_time="18:00", off_time="08:00" + ) + + def test_same_on_and_off_raises(self): + with self.assertRaisesRegex( + ValueError, "The on_time must be before the off_time." + ): + schedule_models.DailySchedule.from_times( + day_name="Monday", on_time="08:00", off_time="08:00" + ) + + def test_invalid_time_zone_raises(self): + with self.assertRaisesRegex(ValueError, "Invalid time zone: OOPS"): + schedule_models.DailySchedule.from_times( + day_name="Monday", + on_time=None, + off_time=None, + time_zone="OOPS", + ) + + def test_naive_on_time_raises(self): + with self.assertRaisesRegex( + ValueError, "The on_time needs to have a time zone." + ): + schedule_models.DailySchedule( + day_name="Monday", + on_time=TIME, + off_time=OFF_TIME_UTC, + time_zone=UTC, + ) + + def test_naive_off_time_raises(self): + with self.assertRaisesRegex( + ValueError, "The off_time needs to have a time zone." + ): + schedule_models.DailySchedule( + day_name="Monday", + on_time=TIME_UTC, + off_time=OFF_TIME, + time_zone=UTC, + ) + + def test_mismatched_on_time_tz_raises(self): + with self.assertRaisesRegex( + ValueError, + "The on_time and off_time must have the same time zone", + ): + schedule_models.DailySchedule( + day_name="Monday", + on_time=TIME_PST, + off_time=OFF_TIME_UTC, + time_zone=UTC, + ) + + def test_mismatched_off_time_tz_raises(self): + with self.assertRaisesRegex( + ValueError, + "The on_time and off_time must have the same time zone", + ): + schedule_models.DailySchedule( + day_name="Monday", + on_time=TIME_UTC, + off_time=OFF_TIME_PST, + time_zone=UTC, + ) + + +# +# WEEKLY SCHEDULE TESTS +# + + +class WeeklyScheduleTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.schedule_dict = { + "Monday": ("06:00", "19:00"), + "Tuesday": ("06:00", "19:00"), + "Wednesday": ("06:00", "19:00"), + "Thursday": ("06:00", "19:00"), + "Friday": ("09:00", "17:00"), + "Saturday": (None, None), + "Sunday": (None, None), + } + self.weekly_schedule = schedule_models.WeeklySchedule.from_dict( + schedule_dict=self.schedule_dict, time_zone=UTC + ) + + def test_init(self): + self.assertIsInstance(self.weekly_schedule, schedule_models.WeeklySchedule) + + def test_day_names(self): + self.assertLen(self.weekly_schedule.daily_schedules, 7) + + day_names = [ + schedule.day_name for schedule in self.weekly_schedule.daily_schedules + ] + self.assertCountEqual(day_names, list(schedule_models.DAY_NAMES)) + + def test_time_zone(self): + self.assertEqual(self.weekly_schedule.time_zone, UTC) + + def test_get_daily_schedule(self): + monday_schedule = self.weekly_schedule.get_daily_schedule("Monday") + self.assertEqual(monday_schedule.day_name, "Monday") + self.assertEqual( + monday_schedule.on_time, datetime.time(6, 0, tzinfo=UTC_INFO) + ) + self.assertEqual( + monday_schedule.off_time, datetime.time(19, 0, tzinfo=UTC_INFO) + ) + + def test_get_daily_schedule_with_invalid_day_name_raises(self): + with self.assertRaisesRegex(ValueError, "Unknown day name: Funday"): + self.weekly_schedule.get_daily_schedule("Funday") + + def test_json_metadata(self): + self.assertEqual( + self.weekly_schedule.json_metadata, + { + "time_zone": UTC, + "daily_schedules": { + "Monday": {"on_time": "06:00", "off_time": "19:00"}, + "Tuesday": {"on_time": "06:00", "off_time": "19:00"}, + "Wednesday": {"on_time": "06:00", "off_time": "19:00"}, + "Thursday": {"on_time": "06:00", "off_time": "19:00"}, + "Friday": {"on_time": "09:00", "off_time": "17:00"}, + "Saturday": {"on_time": None, "off_time": None}, + "Sunday": {"on_time": None, "off_time": None}, + }, + }, + ) + + +class WeeklyScheduleValidationsTest(absltest.TestCase): + + def test_missing_day_raises(self): + with self.assertRaisesRegex( + ValueError, + "Weekly schedule must have a schedule for each day of the week.", + ): + schedule_models.WeeklySchedule.from_dict( + {"Monday": ("08:00", "18:00")}, time_zone=PST + ) + + def test_extra_day_raises(self): + # FYI because dictionaries don't allow duplicate keys, we can't use the + # WeeklySchedule.from_dict constructor to test this validation. + with self.assertRaisesRegex( + ValueError, + "Weekly schedule must have a schedule for each day of the week.", + ): + from_times = schedule_models.DailySchedule.from_times + on_time = "08:00" + off_time = "18:00" + schedule_models.WeeklySchedule([ + from_times(day_name="Monday", on_time=on_time, off_time=off_time), + from_times(day_name="Tuesday", on_time=on_time, off_time=off_time), + from_times(day_name="Wednesday", on_time=on_time, off_time=off_time), + from_times(day_name="Thursday", on_time=on_time, off_time=off_time), + from_times(day_name="Friday", on_time=on_time, off_time=off_time), + from_times(day_name="Saturday", on_time=None, off_time=None), + from_times(day_name="Sunday", on_time=None, off_time=None), + from_times(day_name="Sunday", on_time=None, off_time=None), # Extra + ]) + + +if __name__ == "__main__": + absltest.main() diff --git a/smart_control/llm/utils/schedule_tool.py b/smart_control/llm/utils/schedule_tool.py new file mode 100644 index 00000000..546c89bf --- /dev/null +++ b/smart_control/llm/utils/schedule_tool.py @@ -0,0 +1,373 @@ +"""Schedule tool. + +This tool provides information about the building's operational schedule, by +accessing information such as the current date and time from the environment. + +**Operational Modes** + +This tool can be used by an agent to determine if the building's devices should +be ON or OFF, based on the time of day, day of week, and holiday calendar. + +**Weekly Schedule** + +By default, this tool assumes that workdays are Mondays through Fridays, and +that operational hours are from 7:00 AM to 7:00 PM, but these values can be +customized. This includes the ability to specify different operational hours for +different days of the week. See the `schedule_models.WeeklySchedule` class for +more details. + +**Holiday Calendar** + +We anticipate the need to customize the holiday calendar, because we will be +supporting buildings across different countries. And because even within a given +country, different localities, companies, and building operators may observe +slightly different holiday schedules. + +By default, this tool uses the `holiday.USFederalHolidayCalendar` to determine +the holidays, which provides a good baseline for US-based buildings. However, +you can specify a different holiday calendar, as long as it implements the +`holiday.AbstractHolidayCalendar` interface from pandas (as illustrated by the +example below). + +```python +from pandas.tseries import holiday + +class MyCustomHolidayCalendar(holiday.AbstractHolidayCalendar): + rules = [ + holiday.Holiday("Founder's Day", month=7, day=1), + holiday.Holiday("My Birthday", month=9, day=1), + ] +``` +""" + +import abc +import datetime +import enum +from typing import Any, Final, TypeAlias + +import pandas as pd +from pandas.tseries import holiday +from smart_buildings.smart_control.environment import environment +from smart_buildings.smart_control.llm.utils import schedule_models + +SerializableData: TypeAlias = dict[str, Any] + + +class BuildingOperationalMode(enum.Enum): + """The operational mode of the building (and its devices).""" + + ON = "ON" + OFF = "OFF" + + +OPERATIONAL_MODES = tuple(mode.value for mode in BuildingOperationalMode) + + +DEFAULT_WEEKLY_SCHEDULE: Final[schedule_models.WeeklySchedule] = ( + schedule_models.WeeklySchedule.from_dict( + schedule_dict={ + "Monday": ("07:00", "19:00"), + "Tuesday": ("07:00", "19:00"), + "Wednesday": ("07:00", "19:00"), + "Thursday": ("07:00", "19:00"), + "Friday": ("07:00", "19:00"), + "Saturday": (None, None), + "Sunday": (None, None), + }, + time_zone="US/Pacific", + ) +) + + +class BaseSchedule(abc.ABC): + """Abstract interface providing info about a building's operational schedule. + + Requires a child class to implement the `time_zone` and + `current_local_timestamp` properties, using the building's local time zone. + + Determines if the building's devices should be ON or OFF, based on the time of + day, day of week, and holiday calendar. + + For the holiday calendar, the US federal holiday calendar will be used by + default, however you can customize this by passing in your own implementation + of the `holiday.AbstractHolidayCalendar` interface from pandas. + + The start and end dates are optionally used to filter the range of holidays + included. If not specified, holidays from all available years will be + included. + + Attributes: + time_zone: The time zone to use for all date and time calculations. + current_local_timestamp: The current date and time in the local time zone. + weekly_schedule: The operational hours for each day of the week, using the + building's local time zone. + cal: The holiday calendar to use for determining holidays. Defaults to the + US federal holiday calendar. + start_date: The start date for the holiday calendar (optional). + end_date: The end date for the holiday calendar (optional). + n_upcoming_holidays: The number of upcoming holidays to return. + """ + + def __init__( + self, + weekly_schedule: schedule_models.WeeklySchedule | None = None, + cal: holiday.AbstractHolidayCalendar | None = None, + start_date: str | None = None, + end_date: str | None = None, + n_upcoming_holidays: int = 5, + ): + """Initializes the instance. + + Args: + weekly_schedule: The operational schedule for the week. Defaults to + `DEFAULT_WEEKLY_SCHEDULE`. + cal: The holiday calendar to use for determining holidays. The calendar + must implement the `holiday.AbstractHolidayCalendar` interface. + By default, the US federal holiday calendar is used. + start_date: The start date used to optionally filter the list of + holidays. Defaults to None. + end_date: The end date used to optionally filter the list of + holidays. Defaults to None. + n_upcoming_holidays: The number of upcoming holidays to return. + """ + self.weekly_schedule = weekly_schedule or DEFAULT_WEEKLY_SCHEDULE + self.start_date = start_date + self.end_date = end_date + self.cal = cal or holiday.USFederalHolidayCalendar() + self.n_upcoming_holidays = n_upcoming_holidays + + # + # BASE CONTRACT + # + + @property + @abc.abstractmethod + def time_zone(self) -> str: + """The time zone used for all timestamps and comparisons.""" + raise NotImplementedError + + @property + @abc.abstractmethod + def current_local_timestamp(self) -> pd.Timestamp: + """The current (timezone-aware) date and time in the local timezone.""" + raise NotImplementedError + + # + # IMPLEMENTATION METHODS + # + + @property + def json_metadata(self) -> SerializableData: + """Info to write into a JSON file. Needs to be serializable.""" + holidays_df = self.upcoming_holidays_df.copy() + holidays_df["date"] = holidays_df["date"].dt.strftime("%Y-%m-%d") + holidays_df = holidays_df.rename(columns={"holiday": "name"}) + holidays = holidays_df[["date", "name", "day_name"]].to_dict("records") + + return { + "weekly_schedule": self.weekly_schedule.json_metadata, + "start_date": self.start_date, + "end_date": self.end_date, + "upcoming_holidays": holidays, + } + + # CURRENT DATE AND TIME + + @property + def current_year(self) -> int: + """The current year, in the building's local timezone.""" + return self.current_local_timestamp.year + + @property + def current_date(self) -> datetime.date: + """The current date, in the building's local timezone.""" + return self.current_local_timestamp.date() + + @property + def current_date_str(self) -> str: + """The current date as a string, in the building's local timezone.""" + return self.current_local_timestamp.strftime("%Y-%m-%d") + + @property + def current_time(self) -> datetime.time: + """The current (timezone-aware) time, in the building's local timezone.""" + return self.current_local_timestamp.timetz() + + @property + def current_time_str(self) -> str: + """The current time as a string, in the building's local timezone.""" + return self.current_local_timestamp.strftime("%H:%M") + + @property + def current_weekday_name(self) -> str: + """The current day of the week, in the building's local timezone.""" + return self.current_local_timestamp.strftime("%A") # > "Monday" + + # HOLIDAY CALENDAR + + def _get_holidays( + self, return_name: bool = False + ) -> pd.DatetimeIndex | pd.Series: + """Returns the holidays as a DatetimeIndex or a Series. + + Args: + return_name: Whether to return the holidays as a Series. + + Returns: + A DatetimeIndex or a Series of the holidays. + """ + return self.cal.holidays( + start=self.start_date, end=self.end_date, return_name=return_name + ) + + @property + def holidays(self) -> set[str]: + """The holiday calendar, as a set of string dates (like '2025-01-01').""" + return { + d.strftime("%Y-%m-%d") for d in self._get_holidays(return_name=False) + } + + @property + def holidays_df(self) -> pd.DataFrame: + """The holiday calendar, as a DataFrame.""" + df = self._get_holidays(return_name=True).reset_index() + df.columns = ["date", "holiday"] + df["day_of_year"] = df["date"].dt.dayofyear + df["year"] = df["date"].dt.year + df["day_name"] = df["date"].dt.day_name() + return df + + @property + def upcoming_holidays_df(self) -> pd.DataFrame: + """The next few upcoming holidays, as a DataFrame. + + Use the `n_upcoming_holidays` initialization argument to customize the + number of holidays to be included. + + Note: It is possible for this dataframe to contain fewer than the requested + number of holidays, depending on the current date and the end date. + + Returns: + A DataFrame of the next few upcoming holidays, sorted by date ascending. + """ + df = self.holidays_df[self.holidays_df["date"].dt.date >= self.current_date] + df.sort_values(by="date", inplace=True) + return df.head(self.n_upcoming_holidays) + + @property + def upcoming_holidays(self) -> list[str]: + """The next few upcoming holidays. + + Use the `n_upcoming_holidays` initialization argument to customize the + number of holidays to return. + + Note: It is possible for this list to contain fewer than the requested + number of holidays, depending on the current date and the end date. + + Returns: + A list of strings, like '2025-01-01', sorted by date ascending. + """ + return self.upcoming_holidays_df["date"].dt.strftime("%Y-%m-%d").tolist() + + @property + def is_holiday(self) -> bool: + """Whether the current date is a holiday.""" + return self.current_date_str in self.holidays + + # WEEKLY SCHEDULE + + @property + def current_daily_schedule(self) -> schedule_models.DailySchedule: + """The daily schedule for the current day of week.""" + return self.weekly_schedule.get_daily_schedule(self.current_weekday_name) + + @property + def is_workday(self) -> bool: + """Whether the current date is a workday (not considering holidays).""" + return self.current_daily_schedule.is_operational_day + + # CURRENT OPERATIONAL STATUS + + @property + def is_operational_day(self) -> bool: + """Whether the current date is an operational day.""" + return self.is_workday and not self.is_holiday + + @property + def is_during_operational_hours(self) -> bool: + """Whether the current time is during operational hours.""" + return self.current_daily_schedule.is_during_operational_hours( + self.current_time + ) + + @property + def building_is_operational(self) -> bool: + """Whether the building is operational.""" + return self.is_operational_day and self.is_during_operational_hours + + @property + def building_operational_mode(self) -> BuildingOperationalMode: + """The building's operational mode.""" + if self.building_is_operational: + return BuildingOperationalMode.ON + else: + return BuildingOperationalMode.OFF + + +class ScheduleTool(BaseSchedule): + """Schedule tool using the current date and time in a specified time zone.""" + + def __init__(self, time_zone: str = "UTC", **kwargs): + """Initializes the instance. + + Args: + time_zone: The time zone to use for all date and time calculations. + Defaults to UTC. + **kwargs: Keyword arguments to pass to the base class. + """ + super().__init__(**kwargs) + self._time_zone = time_zone + + @property + def time_zone(self) -> str: + """Returns the time zone used for all date and time calculations.""" + return self._time_zone + + @property + def current_local_timestamp(self) -> pd.Timestamp: + """The current date and time in the local timezone.""" + return pd.Timestamp.now(tz=self.time_zone) + + +class BuildingScheduleTool(BaseSchedule): + """A tool for accessing information about the building's operational schedule. + + Uses the time zone and current local timestamp from the environment to + determine if the building's devices should be ON or OFF, based on the time of + day, day of week, and holiday calendar. + + Attributes: + env: The environment to use for getting the time zone and current timestamp. + **kwargs: Keyword arguments to pass to the base class. + """ + + def __init__(self, env: environment.Environment, **kwargs): + """Initializes the instance. + + Args: + env: The environment to use for getting the time zone and current + timestamp. + **kwargs: Keyword arguments to pass to the base class. + """ + super().__init__(**kwargs) + self.env = env + + @property + def time_zone(self) -> str: + """The building's local time zone, from the environment.""" + return self.env.time_zone + + @property + def current_local_timestamp(self) -> pd.Timestamp: + """The current date and time, in the building's local timezone.""" + return self.env.current_local_timestamp diff --git a/smart_control/llm/utils/schedule_tool_test.py b/smart_control/llm/utils/schedule_tool_test.py new file mode 100644 index 00000000..8ce29299 --- /dev/null +++ b/smart_control/llm/utils/schedule_tool_test.py @@ -0,0 +1,376 @@ +import datetime +from unittest import mock +import zoneinfo + +from absl.testing import absltest +from absl.testing import parameterized +import pandas as pd +from pandas.tseries import holiday +from smart_buildings.smart_control.environment import conftest as env_conftest +from smart_buildings.smart_control.llm.utils import schedule_models +from smart_buildings.smart_control.llm.utils import schedule_tool + +BuildingOperationalMode = schedule_tool.BuildingOperationalMode + +TIME_ZONE = "US/Pacific" +CURRENT_LOCAL_TIMESTAMP = pd.Timestamp("2021-06-01 12:00:00", tz=TIME_ZONE) + +UPCOMING_HOLIDAYS = ( + { + "date": pd.Timestamp("2021-06-18 00:00:00"), + "holiday": "Juneteenth National Independence Day", + "day_of_year": 169, + "year": 2021, + "day_name": "Friday", + }, + { + "date": pd.Timestamp("2021-07-05 00:00:00"), + "holiday": "Independence Day", + "day_of_year": 186, + "year": 2021, + "day_name": "Monday", + }, + { + "date": pd.Timestamp("2021-09-06 00:00:00"), + "holiday": "Labor Day", + "day_of_year": 249, + "year": 2021, + "day_name": "Monday", + }, + { + "date": pd.Timestamp("2021-10-11 00:00:00"), + "holiday": "Columbus Day", + "day_of_year": 284, + "year": 2021, + "day_name": "Monday", + }, + { + "date": pd.Timestamp("2021-11-11 00:00:00"), + "holiday": "Veterans Day", + "day_of_year": 315, + "year": 2021, + "day_name": "Thursday", + }, +) + +SCHEDULE_SCENARIOS = ( + { + "testcase_name": "weekday_morning", + "timestamp": pd.Timestamp("2025-12-12 08:00:00", tz=TIME_ZONE), + "weekday_name": "Friday", + "is_workday": True, + "is_holiday": False, + "is_operational_day": True, + "is_during_operational_hours": True, + "is_operational": True, + "operational_mode": schedule_tool.BuildingOperationalMode.ON, + }, + { + "testcase_name": "weekday_afternoon", + "timestamp": pd.Timestamp("2025-12-12 15:30:00", tz=TIME_ZONE), + "weekday_name": "Friday", + "is_workday": True, + "is_holiday": False, + "is_operational_day": True, + "is_during_operational_hours": True, + "is_operational": True, + "operational_mode": schedule_tool.BuildingOperationalMode.ON, + }, + { + "testcase_name": "weekday_nighttime", + "timestamp": pd.Timestamp("2025-12-12 02:00:00", tz=TIME_ZONE), + "weekday_name": "Friday", + "is_workday": True, + "is_holiday": False, + "is_operational_day": True, + "is_during_operational_hours": False, + "is_operational": False, + "operational_mode": schedule_tool.BuildingOperationalMode.OFF, + }, + { + "testcase_name": "holiday_daytime", # Christmas, a Thursday + "timestamp": pd.Timestamp("2025-12-25 11:00:00", tz=TIME_ZONE), + "weekday_name": "Thursday", + "is_workday": True, + "is_holiday": True, + "is_operational_day": False, + "is_during_operational_hours": True, + "is_operational": False, + "operational_mode": schedule_tool.BuildingOperationalMode.OFF, + }, + { + "testcase_name": "weekend_nighttime", + "timestamp": pd.Timestamp("2025-12-13 02:00:00", tz=TIME_ZONE), + "weekday_name": "Saturday", + "is_workday": False, + "is_holiday": False, + "is_operational_day": False, + "is_during_operational_hours": False, + "is_operational": False, + "operational_mode": schedule_tool.BuildingOperationalMode.OFF, + }, +) + +SCHEDULE_METADATA = { + "weekly_schedule": { + "time_zone": "US/Pacific", + "daily_schedules": { + "Monday": {"on_time": "07:00", "off_time": "19:00"}, + "Tuesday": {"on_time": "07:00", "off_time": "19:00"}, + "Wednesday": {"on_time": "07:00", "off_time": "19:00"}, + "Thursday": {"on_time": "07:00", "off_time": "19:00"}, + "Friday": {"on_time": "07:00", "off_time": "19:00"}, + "Saturday": {"on_time": None, "off_time": None}, + "Sunday": {"on_time": None, "off_time": None}, + }, + }, + "start_date": None, + "end_date": None, + "upcoming_holidays": [ + { + "date": "2021-06-18", + "name": "Juneteenth National Independence Day", + "day_name": "Friday", + }, + { + "date": "2021-07-05", + "name": "Independence Day", + "day_name": "Monday", + }, + {"date": "2021-09-06", "name": "Labor Day", "day_name": "Monday"}, + {"date": "2021-10-11", "name": "Columbus Day", "day_name": "Monday"}, + {"date": "2021-11-11", "name": "Veterans Day", "day_name": "Thursday"}, + ], +} + + +class ScheduleToolTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.mock_timestamp_now = self.enter_context( + mock.patch.object(pd.Timestamp, "now", autospec=True) + ) + self.mock_timestamp_now.return_value = CURRENT_LOCAL_TIMESTAMP + self.schedule = schedule_tool.ScheduleTool(time_zone=TIME_ZONE) + self.expected_class = schedule_tool.ScheduleTool + + def test_initialization(self): + self.assertIsInstance(self.schedule, self.expected_class) + + def test_weekly_schedule(self): + self.assertIsInstance( + self.schedule.weekly_schedule, schedule_models.WeeklySchedule + ) + + def test_time_zone(self): + self.assertEqual(self.schedule.time_zone, TIME_ZONE) + + def test_holiday_calendar(self): + self.assertIsNone(self.schedule.start_date) + self.assertIsNone(self.schedule.end_date) + self.assertIsInstance(self.schedule.cal, holiday.USFederalHolidayCalendar) + + # CURRENT DATE AND TIME + + def test_date_time_properties(self): + with self.subTest(name="current_local_timestamp"): + self.assertEqual( + self.schedule.current_local_timestamp, CURRENT_LOCAL_TIMESTAMP + ) + + with self.subTest(name="current_year"): + self.assertEqual(self.schedule.current_year, 2021) + + with self.subTest(name="current_date"): + self.assertEqual(self.schedule.current_date, datetime.date(2021, 6, 1)) + self.assertEqual(self.schedule.current_date_str, "2021-06-01") + + with self.subTest(name="current_time"): + self.assertEqual( + self.schedule.current_time, + datetime.time(12, 0, tzinfo=zoneinfo.ZoneInfo(TIME_ZONE)), + ) + self.assertEqual(self.schedule.current_time_str, "12:00") + + # HOLIDAY CALENDAR + + def test_get_holidays(self): + with self.subTest(name="as_index"): + holidays = self.schedule._get_holidays(return_name=False) + self.assertIsInstance(holidays, pd.DatetimeIndex) + + with self.subTest(name="as_series"): + holidays = self.schedule._get_holidays(return_name=True) + self.assertIsInstance(holidays, pd.Series) + + def test_holidays(self): + holidays = self.schedule.holidays + self.assertIsInstance(holidays, set) + self.assertGreaterEqual(len(holidays), 2474) + self.assertIn("1970-01-01", holidays) + self.assertIn("2200-12-25", holidays) + + def test_holidays_df(self): + df = self.schedule.holidays_df + self.assertIsInstance(df, pd.DataFrame) + self.assertGreaterEqual(len(df), 2474) + self.assertListEqual( + df.columns.tolist(), + ["date", "holiday", "day_of_year", "year", "day_name"], + ) + + holidays = df["date"].dt.strftime("%Y-%m-%d").tolist() + self.assertIn("1970-01-01", holidays) + self.assertIn("2200-12-25", holidays) + + def test_upcoming_holidays_df(self): + self.assertEqual( + self.schedule.upcoming_holidays_df.to_dict("records"), + list(UPCOMING_HOLIDAYS), + ) + + def test_upcoming_holidays(self): + self.assertEqual( + self.schedule.upcoming_holidays, + [h["date"].strftime("%Y-%m-%d") for h in UPCOMING_HOLIDAYS], + ) + + def test_is_holiday(self): + self.assertFalse(self.schedule.is_holiday) + + def test_json_metadata(self): + self.assertEqual(self.schedule.json_metadata, SCHEDULE_METADATA) + + # DAY OF WEEK + + def test_current_weekday_name(self): + self.assertEqual(self.schedule.current_weekday_name, "Tuesday") + + def test_is_workday(self): + self.assertTrue(self.schedule.is_workday) + + # CURRENT OPERATIONAL STATUS + + def test_is_operational_day(self): + self.assertTrue(self.schedule.is_operational_day) + + def test_is_during_operational_hours(self): + self.assertTrue(self.schedule.is_during_operational_hours) + + def test_building_is_operational(self): + self.assertTrue(self.schedule.building_is_operational) + + def test_building_operational_mode(self): + self.assertEqual( + self.schedule.building_operational_mode, + schedule_tool.BuildingOperationalMode.ON, + ) + + +class BuildingScheduleToolTest(ScheduleToolTest): + + def setUp(self): + super().setUp() + self.env = env_conftest.create_environment( + start_timestamp=CURRENT_LOCAL_TIMESTAMP + ) + self.schedule = schedule_tool.BuildingScheduleTool(env=self.env) + self.expected_class = schedule_tool.BuildingScheduleTool + + +# +# SCENARIO TESTS +# + + +class ScheduleScenariosTest(parameterized.TestCase): + """Performs scenario testing for different operational modes.""" + + @parameterized.named_parameters(*SCHEDULE_SCENARIOS) + def test_building_operation_schedule( + self, + timestamp, + weekday_name, + is_workday, + is_holiday, + is_operational_day, + is_during_operational_hours, + is_operational, + operational_mode, + ): + env = env_conftest.create_environment(start_timestamp=timestamp) + schedule = schedule_tool.BuildingScheduleTool(env=env) + with self.subTest(name="current_date_and_time"): + self.assertEqual(schedule.current_local_timestamp, timestamp) + self.assertEqual(schedule.current_weekday_name, weekday_name) + self.assertEqual(schedule.is_workday, is_workday) + self.assertEqual( + schedule.is_during_operational_hours, is_during_operational_hours + ) + + with self.subTest(name="holiday_calendar"): + self.assertEqual(schedule.is_holiday, is_holiday) + + with self.subTest(name="operational_status"): + self.assertEqual(schedule.is_operational_day, is_operational_day) + self.assertEqual(schedule.building_is_operational, is_operational) + self.assertEqual(schedule.building_operational_mode, operational_mode) + + +# +# CUSTOM HOLIDAY CALENDAR TESTS +# + + +class MyCustomHolidayCalendar(holiday.AbstractHolidayCalendar): + """Custom holiday calendar for testing.""" + + rules = [ + holiday.Holiday("Founder's Day", month=7, day=1), + holiday.Holiday("My Birthday", month=9, day=1), + ] + + +class CustomHolidayScheduleTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.mock_timestamp_now = self.enter_context( + mock.patch.object(pd.Timestamp, "now", autospec=True) + ) + self.custom_calendar = MyCustomHolidayCalendar() + + @parameterized.named_parameters( + dict( + testcase_name="founders_day", + timestamp="2024-07-01 10:00:00", + is_holiday=True, + ), + dict( + testcase_name="my_birthday", + timestamp="2024-09-01 10:00:00", + is_holiday=True, + ), + dict( + testcase_name="christmas_day", + timestamp="2024-12-25 10:00:00", + is_holiday=False, + ), + dict( + testcase_name="new_years_day", + timestamp="2025-01-01 10:00:00", + is_holiday=False, + ), + ) + def test_custom_holidays(self, timestamp, is_holiday): + self.mock_timestamp_now.return_value = pd.Timestamp(timestamp, tz=TIME_ZONE) + schedule = schedule_tool.ScheduleTool( + time_zone=TIME_ZONE, + cal=self.custom_calendar, + ) + self.assertEqual(schedule.is_holiday, is_holiday) + + +if __name__ == "__main__": + absltest.main()