diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3391ead5c..7f73b5cf5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -28,7 +28,7 @@ jobs: ports: - 5784:5672 postgres: - image: postgres:16 + image: ghcr.io/pgmq/pg18-pgmq:v1.10.0 ports: - 5544:5432 env: diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index a873de1eb..0d6510bcc 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -30,7 +30,11 @@ of those changes to CLEARTYPE SRL. | [@thomasLeMeur](https://github.com/thomasLeMeur) | Thomas Le Meur | | [@fregogui](https://github.com/fregogui) | Guillaume Fregosi | | [@pgitips](https://github.com/pgitips) | Pierre Giraud | + + + | [@williampollet](https://github.com/williampollet) | William Pollet | | [@mehdithez](https://github.com/mehdithez) | Zeroual Mehdi | -| [@julien-duponchelle](https://github.com/julien-duponchelle) | Julien Duponchelle | -| [@alisterd51](https://github.com/alisterd51) | Antoine Clarman | +| [@mducros-wm](https://github.com/mducros-wm) | Martin Ducros | +| [@julien-duponchelle](https://github.com/julien-duponchelle) | Julien Duponchelle | +| [@alisterd51](https://github.com/alisterd51) | Antoine Clarman | diff --git a/README.md b/README.md index c4a929758..66f0c555f 100644 --- a/README.md +++ b/README.md @@ -23,10 +23,17 @@ If you want to use it with [RabbitMQ] uv pip install 'remoulade[rabbitmq]' ``` -or if you want to use it with [Redis] +If you want to use it with [PostgreSQL] and [PGMQ] ```console - uv pip install 'remoulade[redis]' + uv pip install 'remoulade[postgres]' +``` + +If you want Redis-backed extras like results or cancellation, add [Redis] to the broker extra you use: + +```console + uv pip install 'remoulade[rabbitmq, redis]' + uv pip install 'remoulade[postgres, redis]' ``` ## Quickstart @@ -84,12 +91,6 @@ If you want to contribute to the project. 
First make a Pull request and get appr This will trigger a CI/CD pipeline that publish the package -## Dashboard - -Check out [SuperBowl](https://github.com/wiremind/super-bowl) a dashboard for real-time monitoring and administrating all your Remoulade tasks. -***See the current progress, enqueue, requeue, cancel and more ...*** -Super easy to use !. - ## Kubernetes Remoulade is tailored to run transparently in containers on [Kubernetes](https://kubernetes.io/) and to make the most of their features. This does not mean it cannot run outside of Kubernetes ;) @@ -110,6 +111,8 @@ remoulade is licensed under the LGPL. Please see [COPYING] and [COPYING.LESSER]: https://github.com/wiremind/remoulade/blob/master/COPYING.LESSER [COPYING]: https://github.com/wiremind/remoulade/blob/master/COPYING +[PostgreSQL]: https://www.postgresql.org/ +[PGMQ]: https://pgmq.github.io/pgmq/ [RabbitMQ]: https://www.rabbitmq.com/ [Redis]: https://redis.io [user guide]: https://remoulade.readthedocs.io/en/latest/guide.html diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 6f0446b84..5c4dd4e4e 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -5,6 +5,28 @@ Changelog All notable changes to this project will be documented in this file. +`7.0.0`_ -- 2026-06-15 +------------ +Breaking changes +^^^^^^^^^^^^^^^^ +* Remove the legacy PostgreSQL state backend. +* Remove the ``DELETE /messages/states`` API route, which only worked with the removed PostgreSQL state backend. +* Remove the ``PUT /scheduled/jobs/`` API route (single-job update); use ``PUT /scheduled/jobs`` instead. +* Rework the broker API around the new PostgreSQL/PGMQ implementation. +* Rename ``Encoder.encode``/``Encoder.decode`` to ``Encoder.encode_in_bytes``/``Encoder.decode_bytes``; custom encoders must now also implement ``Encoder._encode_in_json`` and ``Encoder.decode_json``. +* Rename ``Message.encode``/``Message.decode`` to ``Message.encode_in_bytes``/``Message.decode_bytes``. 
+* ``PydanticEncoder`` no longer depends on ``simplejson``; serialization now goes through Pydantic, so ``Decimal`` values are encoded as JSON strings instead of numbers. +* Repurpose the ``postgres`` extra: it now installs the PGMQ broker dependencies (``sqlalchemy>=2``, ``psycopg>=3``, ``pgmq``) instead of the legacy state backend dependencies (``sqlalchemy<2``, ``psycopg2``). + +Feat +^^^^ +* Add a PostgreSQL/PGMQ broker with partitioned queues, native delayed messages, ``LISTEN/NOTIFY`` wakeups, and queue join support. + +Changed +^^^^^^^ +* Restore the main APIs after the broker refactor. +* Update the documentation, examples, CI, and test suite for the new stack. + ======= `6.2.0`_ -- 2026-05-18 ------------- diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst index 23303cf4e..c3405fb98 100644 --- a/docs/source/getting_started.rst +++ b/docs/source/getting_started.rst @@ -12,7 +12,6 @@ By the end of this tutorial, you will be able to do the following: * :ref:`create a pipeline of tasks that will sequentially process data` * :ref:`use the result middleware to wait and get actor results` * :ref:`use the remoulade scheduler to periodically run tasks` -* :ref:`use SuperBowl to monitor and manage your tasks` Prerequisites ------------- @@ -341,55 +340,6 @@ To set up the scheduler, we instantiate it, set it as the global scheduler, and If you run this script and get back to the worker terminal, you will see ``get_weather`` being executed every 10 seconds. -Monitoring and Managing your tasks ----------------------------------- - -To monitor and manage your tasks, you can use the Superbowl_ dashboard. - -.. _Superbowl: https://github.com/wiremind/super-bowl - -First, you will need to install Node.js_. Then, clone Superbowl_ in another directory, install its dependencies and run it:: - - $ cd .. - $ git clone https://github.com/wiremind/super-bowl.git - $ npm install - $ npm run serve - -.. 
_Node.js: https://nodejs.org/en/download/ - -Now, if you open ``localhost:8080`` in your browser, you will see the SuperBowl dashboard, but you will not see your messages yet. In order to see and manage them, you will have to modify the ``get_weather.py`` script to serve the remoulade api. - -.. code-block:: python - :caption: get_weather.py - :emphasize-lines: 4, 22, 23 - - import requests - import remoulade - from remoulade.brokers.rabbitmq import RabbitmqBroker - from remoulade.api.main import app - - - @remoulade.actor - def get_weather(city): - url = f"https://goweather.herokuapp.com/weather/{city}" - - response = requests.get(url).json() - - url_endpoint = - text = f'{city}: {response["description"]}' - requests.post(url_endpoint, json=text) - - - rabbitmq_broker = RabbitmqBroker() - remoulade.set_broker(rabbitmq_broker) - remoulade.declare_actors([get_weather]) - - if __name__ == "__main__": - app.run(host="localhost", port=5005) - -Now you can use the Enqueue tab to enqueue messages with custom arguments, and then see their progress in the messages tab. -Additionally, if you run groups or scheduled jobs in your script, you will be able to see them in their respective tabs. - Next Steps ---------- diff --git a/docs/source/global.rst b/docs/source/global.rst index a27b10ac4..416529144 100644 --- a/docs/source/global.rst +++ b/docs/source/global.rst @@ -88,5 +88,7 @@ .. _gevent: http://www.gevent.org/ .. _RabbitMQ: https://www.rabbitmq.com +.. _PostgreSQL: https://www.postgresql.org/ +.. _PGMQ: https://pgmq.github.io/pgmq/ .. _Redis: https://redis.io .. _Dramatiq: https://dramatiq.io diff --git a/docs/source/guide.rst b/docs/source/guide.rst index e786f7d63..0f41ed9c8 100644 --- a/docs/source/guide.rst +++ b/docs/source/guide.rst @@ -170,7 +170,7 @@ actor an invalid URL. Let's try it:: Message Retries ---------------- +^^^^^^^^^^^^^^^ If an error occurs during message processing, it will be terminated with a failure message. 
Alternatively, you can add the |Retries| Middleware to the broker and set the max_retries or retry_when option to automatically retry your message on failure. @@ -208,7 +208,7 @@ max_retries The maximum number of times a message should be retried. Default to ``0``. min_backoff -^^^^^^^^^^^ +^^^^^^^^^^^^^^^ The minimum number of milliseconds of backoff to apply between retries. Must be greater than 100 milliseconds. Defaults to 15 seconds. @@ -331,6 +331,11 @@ milliseconds):: Keep in mind that *your message broker is not a database*. Scheduled messages should represent a small subset of all your messages. +On brokers that emulate delay in worker memory, the enqueued message +will carry an ``eta`` option. ``PostgresBroker`` stores delayed messages in +PostgreSQL natively instead, so the message it returns does not include +``eta``. + Prioritizing Messages --------------------- @@ -381,7 +386,7 @@ Message Brokers --------------- Remoulade abstracts over the notion of a message broker and currently -supports RabbitMQ out of the box. +supports RabbitMQ and PostgreSQL/PGMQ out of the box. RabbitMQ Broker ^^^^^^^^^^^^^^^ @@ -398,6 +403,51 @@ execution:: remoulade.set_broker(rabbitmq_broker) +Postgres Broker +^^^^^^^^^^^^^^^ + +To configure PostgreSQL/PGMQ, install ``remoulade[postgres]`` and +instantiate a ``PostgresBroker`` with a PostgreSQL URL as early as possible +during your program's execution. This broker must be used with a PostgreSQL +user allowed to create and delete tables:: + + import remoulade + + from remoulade.brokers.postgres import PostgresBroker + + postgres_broker = PostgresBroker(url="postgresql://remoulade@localhost:5432/remoulade") + remoulade.set_broker(postgres_broker) + +PGMQ handles delayed messages natively, so ``send_with_options(delay=...)`` +does not create a worker-side delay queue or add an ``eta`` option to +the message. 
+ +Each queue is created as a **partitioned** PGMQ queue (through +``pgmq.create_partitioned``), which relies on the PostgreSQL ``pg_partman`` +extension. Messages are stored in time-based partitions of the queue table, +and once a message is acked or nacked it is moved to the queue's archive +table, which is partitioned the same way. Two broker parameters control this: + +``archive_partition_interval_in_days`` (default ``1``) + The time span covered by a single partition. Smaller values create more, + smaller partitions; larger values create fewer, larger ones. + +``archive_retention_interval_in_days`` (default ``7``) + How long a partition is kept before ``pg_partman`` drops it. Archived + messages older than this window are removed together with their partition, + so set it comfortably above your longest expected processing and retry + window. + +.. note:: + + Partitioning is maintained by ``pg_partman``: its maintenance routine + (``partman.run_maintenance_proc()``) must run periodically — through the + ``pg_partman`` background worker or a ``pg_cron`` job — to create upcoming + partitions ahead of time and drop expired ones. The official PGMQ Docker + image (``ghcr.io/pgmq/pg18-pgmq``) ships ``pg_partman`` and schedules this + for you; on a self-managed or hosted PostgreSQL you must enable it yourself. + + Local Broker ^^^^^^^^^^^^^^^ @@ -472,4 +522,3 @@ synchronously by calling them as you would normal functions. .. _pytest fixtures: https://docs.pytest.org/en/latest/fixture.html .. 
_priority documentation: https://www.rabbitmq.com/priority.html - diff --git a/docs/source/index.rst b/docs/source/index.rst index 4a1674893..a0e19d881 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -49,9 +49,14 @@ If you want to use it with RabbitMQ_:: $ pip install -U 'remoulade[rabbitmq]' -Or if you want to use it with Redis_:: +Or if you want to use it with PostgreSQL_ and PGMQ_:: - $ pip install -U 'remoulade[redis]' + $ pip install -U 'remoulade[postgres]' + +Or if you want to use Redis_ for results and cancellation:: + + $ pip install -U 'remoulade[rabbitmq, redis]' + $ pip install -U 'remoulade[postgres, redis]' Read the :doc:`guide` if you're ready to get started. diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 1057dc0c3..ac7c77bea 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -9,16 +9,21 @@ Remoulade supports Python versions 3.12 and up and is installable via Via pip ------- +remoulade can be used with a RabbitMQ_ or a PostgreSQL_ broker. -To install remoulade, simply run the following command in a terminal:: +If you want to use it with RabbitMQ_, simply run the following command in a terminal:: $ pip install -U 'remoulade[rabbitmq]' -Remoulade use RabbitMQ_ as message broker. +If you want to use PostgreSQL_ with PGMQ_ instead, install:: -If you would like to use it with Redis_ to store the results then run: + $ pip install -U 'remoulade[postgres]' + +If you would like to use Redis_-backed extras like results or +cancellation, add the ``redis`` extra to whichever broker you choose:: $ pip install -U 'remoulade[rabbitmq, redis]' + $ pip install -U 'remoulade[postgres, redis]' If you don't have `pip`_ installed, check out `this guide`_. @@ -32,6 +37,7 @@ extra requirements: Name Description ============= ======================================================================================= ``rabbitmq`` Installs the required dependencies for using Remoulade with RabbitMQ. 
+``postgres`` Installs the required dependencies for using Remoulade with PostgreSQL and PGMQ. ``redis`` Installs the required dependencies for using Remoulade with Redis. ============= ======================================================================================= diff --git a/docs/source/reference.rst b/docs/source/reference.rst index ab5123adb..43382c171 100644 --- a/docs/source/reference.rst +++ b/docs/source/reference.rst @@ -62,6 +62,9 @@ Brokers .. autoclass:: remoulade.brokers.rabbitmq.RabbitmqBroker :members: :inherited-members: +.. autoclass:: remoulade.brokers.postgres.PostgresBroker + :members: + :inherited-members: .. autoclass:: remoulade.brokers.stub.StubBroker :members: :inherited-members: diff --git a/examples/composition/composition/actors.py b/examples/composition/composition/actors.py index 58ad693ef..a3ab5c607 100644 --- a/examples/composition/composition/actors.py +++ b/examples/composition/composition/actors.py @@ -8,7 +8,7 @@ from remoulade.results import Results from remoulade.results.backends import RedisBackend from remoulade.state import MessageState -from remoulade.state.backends import PostgresBackend +from remoulade.state.backends import RedisBackend as RedisStateBackend encoder = PickleEncoder() backend = RedisBackend(encoder=encoder) @@ -16,7 +16,7 @@ broker.add_middleware(Results(backend=backend)) remoulade.set_broker(broker) remoulade.set_encoder(encoder) -remoulade.get_broker().add_middleware(MessageState(backend=PostgresBackend())) +remoulade.get_broker().add_middleware(MessageState(backend=RedisStateBackend())) @remoulade.actor(store_results=True) diff --git a/pyproject.toml b/pyproject.toml index 8a173cfe9..bf279d855 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,8 +23,8 @@ classifiers = [ rabbitmq = ["amqpstorm>=2.6,<3"] redis = ["redis>=7.0.0"] server = ["flask~=2.3.3", "marshmallow>=3,<4", "flask-apispec"] -postgres = ["sqlalchemy>=1.4.29,<2", "psycopg2>=2.9.11"] -pydantic = ["pydantic>=2.12", "simplejson"] 
+postgres = ["sqlalchemy>=2.0,<3", "psycopg>=3.2", "pgmq[sqlalchemy]>=1.1.1,<2"] +pydantic = ["pydantic>=2.12"] limits = ["limits~=5.3.0"] tracing = ["opentelemetry-api>=1.20"] @@ -34,10 +34,10 @@ dev = [ "flask~=2.3.3", "marshmallow>=3,<4", "flask-apispec", - "sqlalchemy>=1.4.29,<2", - "psycopg2>=2.9.11", + "sqlalchemy>=2.0,<3", + "pgmq[sqlalchemy]>=1.1.1", + "psycopg>=3.2", "pydantic>=2.12", - "simplejson", "limits~=5.3.0", # Docs "alabaster", @@ -48,10 +48,9 @@ dev = [ # Linting "ruff", "mypy~=1.18.2", - "sqlalchemy[mypy]>=1.4.29,<2", + "sqlalchemy[mypy]>=2.0,<3", "types-redis", "types-python-dateutil", - "types-simplejson", "types-requests", # Misc "pre-commit", @@ -97,10 +96,6 @@ include = ["remoulade*"] testpaths = ["tests"] asyncio_mode = "auto" markers = ["confirm_delivery", "group_transaction"] -filterwarnings = [ - "error::sqlalchemy.exc.RemovedIn20Warning", - "error::sqlalchemy.exc.MovedIn20Warning", -] [tool.ruff] target-version = "py312" diff --git a/remoulade/__init__.py b/remoulade/__init__.py index 4d193704c..78da3d860 100644 --- a/remoulade/__init__.py +++ b/remoulade/__init__.py @@ -33,6 +33,7 @@ QueueJoinTimeout, QueueNotFound, RemouladeError, + UnsupportedMessageEncoding, ) from .generic import GenericActor from .logging import get_logger @@ -71,6 +72,7 @@ # Errors "RemouladeError", "Result", + "UnsupportedMessageEncoding", # Workers "Worker", "actor", diff --git a/remoulade/api/main.py b/remoulade/api/main.py index b0a9831d2..c5d216a4e 100644 --- a/remoulade/api/main.py +++ b/remoulade/api/main.py @@ -1,7 +1,7 @@ """This file describe the API to get the state of messages""" import sys -from typing import Any, TypedDict +from typing import Any from flask import Flask from flask_apispec import marshal_with @@ -114,7 +114,7 @@ def get_results(message_id): max_size = 1e4 try: result = Result[Any](message_id=message_id).get() - encoded_result = get_encoder().encode(result).decode("utf-8") + encoded_result = 
get_encoder().encode_in_bytes(result).decode("utf-8") size_result = sys.getsizeof(encoded_result) if size_result >= max_size: encoded_result = f"The result is too big {size_result / 1e6}M" @@ -137,11 +137,6 @@ def enqueue_message(**kwargs): return {"result": "ok"} -class GroupMessagesT(TypedDict): - group_id: str - messages: list[dict] - - @app.route("/actors") @marshal_with(ActorResponseSchema) def get_actors(): diff --git a/remoulade/api/scheduler.py b/remoulade/api/scheduler.py index d3dfbd6c1..9f17c03fd 100644 --- a/remoulade/api/scheduler.py +++ b/remoulade/api/scheduler.py @@ -108,16 +108,6 @@ def delete_job(scheduler, job_hash): scheduler.delete_job(job_hash) -@scheduler_bp.route("/jobs/", methods=["PUT"]) -@doc(tags=["scheduler"]) -@marshal_with(ScheduleResponseSchema) -@validate_schema(ScheduledJobSchema) -@with_scheduler -def update_job(scheduler, job_hash, **kwargs): - scheduler.delete_job(job_hash) - scheduler.add_job(ScheduledJob(**kwargs)) - - @scheduler_bp.route("/jobs", methods=["PUT"]) @doc(tags=["scheduler"]) @marshal_with(ScheduleResponseSchema) @@ -129,4 +119,4 @@ def update_jobs(scheduler, **kwargs): scheduler.add_job(ScheduledJob(**job_dict)) -scheduler_routes = [get_jobs, add_job, delete_job, update_job, update_jobs] +scheduler_routes = [get_jobs, add_job, delete_job, update_jobs] diff --git a/remoulade/api/state.py b/remoulade/api/state.py index 4bb82f086..4982e5d1f 100644 --- a/remoulade/api/state.py +++ b/remoulade/api/state.py @@ -5,22 +5,12 @@ from remoulade import get_broker from remoulade.state import State, StateStatusesEnum -from remoulade.state.backends import PostgresBackend from .apispec import validate_schema messages_bp = Blueprint("messages", __name__, url_prefix="/messages") -class DeleteSchema(Schema): - """ - Class to validate delete body data in /messages/states - """ - - max_age = fields.Int(allow_none=True) - not_started = fields.Bool(load_default=False) - - class StatesParamsSchema(Schema): """ Class to validate the 
state search parameters @@ -81,17 +71,6 @@ def get_states(**kwargs): return {"data": data, "count": backend.get_states_count(**kwargs)} -@messages_bp.route("/states", methods=["DELETE"]) -@doc(tags=["state"]) -@validate_schema(DeleteSchema) -def clean_states(**kwargs): - backend = get_broker().get_state_backend() - if not isinstance(backend, PostgresBackend): - return {"error": "deleting states is only supported by the PostgresBackend"}, 400 - get_broker().get_state_backend().clean(**kwargs) - return {"result": "ok"} - - @messages_bp.route("/states/") @doc(tags=["state"]) @marshal_with(StateSchema) @@ -103,4 +82,4 @@ def get_state(message_id): return data.as_dict() -messages_routes = [get_states, clean_states, get_state] +messages_routes = [get_states, get_state] diff --git a/remoulade/broker.py b/remoulade/broker.py index 52bfa97f3..8732b6d7e 100644 --- a/remoulade/broker.py +++ b/remoulade/broker.py @@ -194,6 +194,8 @@ class Broker: overwrite when they are declared. """ + supports_native_delay = False + def __init__(self, middleware: "Iterable[Middleware] | None" = None): self.logger = get_logger(__name__, type(self)) self.actors: dict[str, Actor] = {} @@ -420,7 +422,11 @@ def declare_queue(self, queue_name: str) -> None: # pragma: no cover raise NotImplementedError def _apply_delay(self, message: "Message", delay: int | None = None) -> "Message": - raise NotImplementedError + """If your broker doesn't support native delay, you need to override this method""" + if self.supports_native_delay: + return message + else: + raise NotImplementedError("delay is not supported natively, you need to implement it") def _enqueue(self, message: "Message", *, delay: int | None = None) -> "Message": raise NotImplementedError diff --git a/remoulade/brokers/postgres.py b/remoulade/brokers/postgres.py new file mode 100644 index 000000000..572fe76cc --- /dev/null +++ b/remoulade/brokers/postgres.py @@ -0,0 +1,780 @@ +# This file is a part of Remoulade. 
+# +# Copyright (C) 2026 WIREMIND SAS +# +# Remoulade is free software; you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or (at +# your option) any later version. +# +# Remoulade is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . +import logging +import math +import time +from collections import deque +from collections.abc import Iterable, Iterator +from contextlib import contextmanager +from datetime import UTC, datetime, timedelta +from threading import Event, Lock, Thread, local +from typing import TYPE_CHECKING, Any, Final, override +from urllib.parse import urlparse + +import psycopg +from pgmq import SQLAlchemyPGMQueue +from pgmq.messages import Message as PostgresQueueMessage +from psycopg import sql as psycopg_sql +from sqlalchemy import Connection, text + +from ..broker import Broker, Consumer, MessageProxy +from ..errors import QueueJoinTimeout, QueueNotFound, UnsupportedMessageEncoding +from ..message import Message + +if TYPE_CHECKING: + from sqlalchemy import Engine + + from ..middleware import Middleware + +PostgresPayload = dict[str, Any] +LISTEN_NOTIFY_THROTTLE_MS: Final[int] = 250 +LISTENER_RECONNECT_BACKOFF_MIN_S: Final[float] = 0.5 +LISTENER_RECONNECT_BACKOFF_MAX_S: Final[float] = 30.0 + + +def _milliseconds_to_seconds(milliseconds: int) -> int: + """Convert milliseconds to whole seconds, rounding up for PGMQ inputs.""" + return max(1, math.ceil(milliseconds / 1000)) + + +class PostgresBroker(Broker): + """A broker backed by PostgreSQL via the PGMQ extension. 
+ + PGMQ handles delayed messages natively, so delayed sends stay in + PostgreSQL instead of being staged in worker memory. + + Connection budget (per worker process): + * the shared SQLAlchemy pool (``pool_size``), used by reads, acks and + the heartbeat, bounded regardless of the number of consumed queues; + * a single shared LISTEN/NOTIFY connection (one per process, not one per + queue) when ``enable_listen_notify`` is True. + + Set ``enable_listen_notify=False`` for a poll-only mode that opens no + dedicated listener connection. This is required behind a connection pooler + in transaction pooling mode (e.g. pgbouncer), where LISTEN/NOTIFY is not + supported, and useful to cap connections on very large fan-outs. + """ + + supports_native_delay = True + + def __init__( + self, + *, + url: str, + middleware: list["Middleware"] | None = None, + group_transaction: bool = False, + archive_partition_interval_in_days: int = 1, + archive_retention_interval_in_days: int = 7, + visibility_timeout_ms: int = 30_000, + heartbeat_interval_ms: int = 10_000, + enable_listen_notify: bool = True, + pool_size: int = 10, + engine: "Engine | None" = None, + ) -> None: + """Initialize a PostgreSQL-backed broker using the PGMQ extension. + + Parameters: + url(str): PostgreSQL URL in plain format (`postgresql://...`), used both by SQLAlchemy and psycopg. + The url must be the creds for a user who can create and delete tables + middleware(list[Middleware] | None): Middleware stack applied to this broker. + group_transaction(bool): If True, wraps group and pipeline operations in a single transaction. + archive_partition_interval_in_days(int): Partition interval passed to PGMQ when creating partitioned queues. + archive_retention_interval_in_days(int): Retention interval passed to PGMQ for archive partitions. + visibility_timeout_ms(int): Message visibility timeout in milliseconds after read; must be greater than 0. 
+ heartbeat_interval_ms(int): Heartbeat interval in milliseconds used to extend in-flight message visibility + must be greater than 0 and lower than visibility_timeout_ms. + enable_listen_notify(bool): If True (default), consumers are woken by a single process-wide LISTEN/NOTIFY + connection. If False, consumers poll only and no dedicated listener connection is opened (required behind + a transaction-pooling connection pooler such as pgbouncer). + pool_size(int): Size of the shared SQLAlchemy connection pool. Ignored when ``engine`` is provided. + engine(Engine | None): A pre-configured SQLAlchemy engine to reuse instead of letting PGMQ build one, so + the pool can be sized and shared by the caller. + """ + super().__init__(middleware=middleware) + if visibility_timeout_ms <= 0: + raise ValueError("visibility_timeout_ms must be greater than 0") + if heartbeat_interval_ms <= 0 or heartbeat_interval_ms >= visibility_timeout_ms: + raise ValueError("heartbeat_interval_ms must be greater than 0 and lower than visibility_timeout_ms") + + self.url = urlparse(url).geturl() + self.state = local() + self.group_transaction = group_transaction + self.archive_partition_interval = self.convert_days_in_partman_syntax(archive_partition_interval_in_days) + self.archive_retention_interval = self.convert_days_in_partman_syntax(archive_retention_interval_in_days) + self.visibility_timeout_ms = visibility_timeout_ms + self.heartbeat_interval_ms = heartbeat_interval_ms + self.visibility_timeout_seconds = _milliseconds_to_seconds(visibility_timeout_ms) + self.heartbeat_interval_seconds = heartbeat_interval_ms / 1000 + self.enable_listen_notify = enable_listen_notify + + self.client = SQLAlchemyPGMQueue( + conn_string=url, + init_extension=False, + vt=self.visibility_timeout_seconds, + pool_size=pool_size, + engine=engine, + ) + + self._listener = _PostgresListener(self.url, self.logger) if enable_listen_notify else None + + @staticmethod + def 
convert_days_in_partman_syntax(interval_in_day: int) -> str: + """Convert int into partman syntax""" + if interval_in_day <= 0: + raise ValueError("interval_in_day must be greater than 0") + if interval_in_day == 1: + return "1 day" + return f"{interval_in_day} days" + + @override + @contextmanager + def tx(self) -> Iterator[None]: + """Run broker operations inside a single SQL transaction. + + The active SQLAlchemy connection is stored on thread-local state so + queue operations can reuse it while the context is open. + """ + with self.client.engine.begin() as connection: + self.state.transaction_connection = connection + try: + yield + finally: + self.state.transaction_connection = None + + @property + def _current_connection(self) -> Connection | None: + """Return the transactional connection, if one is active.""" + return getattr(self.state, "transaction_connection", None) + + def _try_enable_notify(self, queue_name: str) -> None: + """Try to enable LISTEN/NOTIFY for a queue. + + Failures are logged and the consumer later falls back to polling. + """ + try: + self.client.enable_notify( + queue_name, + throttle_interval_ms=LISTEN_NOTIFY_THROTTLE_MS, + conn=self._current_connection, + ) + except Exception as e: + self.logger.warning( + "Failed to enable LISTEN/NOTIFY for queue %s; consumer will fall back to polling. Error: %s", + queue_name, + e, + exc_info=True, + ) + + def _queue_exists(self, queue_name: str) -> bool: + """Return whether the queue already exists in PostgreSQL.""" + return queue_name in {queue.queue_name for queue in self.client.list_queues()} + + @override + def close(self) -> None: + """Stop the shared listener and dispose the underlying PGMQ client.""" + if self._listener is not None: + self._listener.close() + self.client.dispose() + + @override + def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> "_PostgresConsumer": + """Create a consumer for a declared queue. 
+ + Raises: + QueueNotFound: If the queue has not been declared. + """ + if queue_name not in self.queues: + raise QueueNotFound(queue_name) + return _PostgresConsumer(self, queue_name=queue_name, prefetch=prefetch, timeout=timeout) + + @override + def declare_queue(self, queue_name: str) -> None: + """Create a partitioned PGMQ queue if it does not already exist.""" + if queue_name in self.queues: + return + with self.tx(): + if self._current_connection is None: + raise ValueError("cannot be None we are inside a tx") + self._current_connection.execute( + text("SELECT pg_advisory_xact_lock(hashtext(:k))"), {"k": f"remoulade.declare_queue.{queue_name}"} + ) # Concurrency issues if broker try to create the same queues + self.client.validate_queue_name(queue_name, conn=self._current_connection) + queue_exists = self._queue_exists(queue_name) + + if not queue_exists: + self.emit_before("declare_queue", queue_name) + self.client.create_partitioned_queue( + queue_name, + partition_interval=self.archive_partition_interval, + retention_interval=self.archive_retention_interval, + conn=self._current_connection, + ) + if self.enable_listen_notify: + self._try_enable_notify(queue_name) + + self.queues[queue_name] = None + + if not queue_exists: + self.emit_after("declare_queue", queue_name) + + def _encode_message(self, message: "Message") -> PostgresPayload: + """Encode a Remoulade message into a JSON object payload for PGMQ. + + The encoder is responsible for validating that the payload is valid + JSON; here we only enforce the PGMQ-specific requirement that it be a + JSON object. + + Raises: + UnsupportedMessageEncoding: If the encoded payload is not valid JSON + or is not a JSON object. 
+ """ + try: + payload = message.encode_in_json() + except (TypeError, ValueError, UnsupportedMessageEncoding) as exc: + raise UnsupportedMessageEncoding("PGMQ messages must contain JSON objects.") from exc + + if not isinstance(payload, dict): + raise UnsupportedMessageEncoding("PGMQ messages must contain JSON objects.") + + return payload + + @override + def _enqueue(self, message: "Message", *, delay: int | None = None) -> "Message": + """Send a message to PGMQ, optionally delayed by milliseconds.""" + if message.queue_name not in self.queues: + raise QueueNotFound(message.queue_name) + + payload = self._encode_message(message) + visible_at = datetime.now(UTC) + timedelta(milliseconds=delay) if delay is not None else None + self.client.send( + message.queue_name, + payload, + conn=self._current_connection, + delay=visible_at, + ) + + return message + + @override + def flush(self, queue_name: str) -> None: + """Remove every message currently stored in a queue.""" + if queue_name not in self.queues: + raise QueueNotFound(queue_name) + + self.client.purge(queue_name, conn=self._current_connection) + + @override + def flush_all(self) -> None: + """Purge every declared queue.""" + for queue_name in self.queues: + self.flush(queue_name) + + def _count_enqueued_messages(self, queue_name: str) -> int: + """Count every message stored in the queue.""" + return self.client.metrics(queue_name).queue_length + + @override + def join( + self, + queue_name: str, + min_successes: int = 10, + idle_time: int = 100, + *, + timeout: int | None = None, + ) -> None: + """Wait for all the messages on the given queue to be processed. + + This method checks the full PGMQ queue table and therefore waits for + all states: visible messages, invisible in-flight messages and native + delayed messages. + + Parameters: + queue_name(str): The queue to wait on. + min_successes(int): The minimum number of times the queue should be + observed as empty. 
+ idle_time(int): The number of milliseconds to wait between checks. + timeout(Optional[int]): The max amount of time, in milliseconds, to + wait on this queue. + + Raises: + QueueNotFound: If the given queue was never declared. + QueueJoinTimeout: When the timeout elapses. + """ + if queue_name not in self.queues: + raise QueueNotFound(queue_name) + + deadline = time.monotonic() + timeout / 1000 if timeout is not None else None + successes = 0 + + while successes < min_successes: + if deadline and time.monotonic() >= deadline: + raise QueueJoinTimeout(queue_name) + + total_messages = self._count_enqueued_messages(queue_name) + if total_messages == 0: + successes += 1 + if successes >= min_successes: + return + else: + successes = 0 + + time.sleep(idle_time / 1000) + + +class _PostgresListener: + """Process-wide LISTEN/NOTIFY dispatcher shared by all consumers of a broker. + + A single psycopg connection LISTENs on every consumed queue's + ``pgmq.q_.INSERT`` channel and a single background thread routes each + notification to the matching consumer's wake event. This keeps the number + of dedicated listener connections at one per process instead of one per + consumed queue. + + The connection is opened synchronously when the first consumer registers. + On any connection failure ``available`` flips to False and + consumers transparently fall back to polling, while the thread keeps + retrying (with capped backoff) to reopen the connection and re-LISTEN on + every registered channel. ``available`` flips back to True as soon as a + reconnection succeeds, so consumers resume LISTEN/NOTIFY automatically + instead of polling for the rest of the process lifetime. 
+ """ + + def __init__(self, url: str, logger: logging.Logger) -> None: + self._url = url + self._logger = logger + self._lock = Lock() + self._stop = Event() + self._connection: psycopg.Connection[Any] | None = None + self._thread: Thread | None = None + self._started = False + # An Event rather than a bare bool so reads from consumer threads and + # writes from the dispatch thread are synchronized without relying on + # CPython's GIL making attribute access atomic. + self._available = Event() + # Several consumers may register on the same queue, so map each queue to + # the set of their wake events instead of a single one. + self._events: dict[str, set[Event]] = {} + self._channel_to_queue: dict[str, str] = {} + self._pending_listen: set[str] = set() + + @property + def available(self) -> bool: + """Whether the shared LISTEN/NOTIFY connection is currently usable.""" + return self._available.is_set() + + @staticmethod + def _channel_for(queue_name: str) -> str: + return f"pgmq.q_{queue_name}.INSERT" + + def register(self, queue_name: str, event: Event) -> None: + """Register a consumer's wake event and ensure its channel is listened to.""" + with self._lock: + self._events.setdefault(queue_name, set()).add(event) + self._channel_to_queue[self._channel_for(queue_name)] = queue_name + self._pending_listen.add(queue_name) + starting = not self._started + if starting: + self._started = True + # Start outside the lock: the initial connection attempt acquires the + # same lock to snapshot channels, and would otherwise deadlock. + if starting: + self._start() + + def unregister(self, queue_name: str, event: Event) -> None: + """Stop routing notifications to a consumer that is shutting down. + + Only the queue's channel routing is dropped once its last consumer + leaves, so a closing consumer never stops notifications for a sibling + still consuming the same queue. 
+ """ + with self._lock: + events = self._events.get(queue_name) + if events is None: + return + events.discard(event) + if not events: + self._events.pop(queue_name, None) + self._channel_to_queue.pop(self._channel_for(queue_name), None) + self._pending_listen.discard(queue_name) + + def _start(self) -> None: + """Open the shared connection (best effort) and start the dispatch thread. + + The initial connection attempt is synchronous so ``available`` reflects + a real outcome by the time the first consumer starts reading. If it + fails, the dispatch thread keeps retrying with backoff, reopening the + connection and re-LISTENing on every registered channel, so a transient + database outage degrades consumers to polling instead of disabling + LISTEN/NOTIFY for the rest of the process lifetime. + """ + self._open_connection() + self._thread = Thread(target=self._run, name="postgres-listener", daemon=True) + self._thread.start() + + def _open_connection(self) -> bool: + """Open the shared connection and LISTEN on every registered channel. + + Called once synchronously when the listener starts and then from the + dispatch thread on every reconnection. Must not be called while holding + ``self._lock``, which it acquires to snapshot the registered channels. + Returns True when the connection is ready and all channels are listened + to, False when the connection could not be opened (the caller then backs + off and retries). + """ + connection = None + try: + connection = psycopg.connect(self._url, autocommit=True) + with self._lock: + # A fresh connection listens to nothing, so re-LISTEN every + # channel currently registered rather than only those pending + # since the last loop. 
+ channels = list(self._channel_to_queue) + self._pending_listen.clear() + for channel in channels: + connection.execute(psycopg_sql.SQL("LISTEN {}").format(psycopg_sql.Identifier(channel))) + except Exception as e: + self._logger.warning( + "Failed to open shared LISTEN/NOTIFY connection; consumers will fall back to polling. Error: %s", + str(e), + exc_info=True, + ) + if connection is not None: + try: + connection.close() + except Exception: + self._logger.debug("Failed to close partially-opened listener connection", exc_info=True) + return False + self._connection = connection + self._available.set() + return True + + def _drop_connection(self) -> None: + """Mark the listener unavailable and discard the current connection. + + Waiting consumers are woken so they fall back to polling immediately + instead of blocking on their wake event for the full timeout. + """ + self._available.clear() + self._wake_all() + if self._connection is not None: + try: + self._connection.close() + except Exception as e: + self._logger.error("Failed to close shared listener connection: %s", str(e)) + self._connection = None + + def _drain_pending_listen(self) -> None: + """Issue LISTEN for queues registered since the last loop (listener thread only).""" + with self._lock: + pending = list(self._pending_listen) + self._pending_listen.clear() + if self._connection is None: + return + for queue_name in pending: + channel = self._channel_for(queue_name) + self._connection.execute(psycopg_sql.SQL("LISTEN {}").format(psycopg_sql.Identifier(channel))) + + def _wake_channel(self, channel: str) -> None: + """Wake every consumer registered for a notification channel.""" + with self._lock: + queue_name = self._channel_to_queue.get(channel) + events = list(self._events.get(queue_name, ())) if queue_name is not None else [] + for event in events: + event.set() + + def _wake_all(self) -> None: + """Wake every registered consumer (used on shutdown or listener failure).""" + with self._lock: + 
events = [event for queue_events in self._events.values() for event in queue_events] + for event in events: + event.set() + + def _run(self) -> None: + """Route notifications to consumer events, reconnecting on errors. + + When the connection is down it is reopened with an exponential backoff + capped at ``LISTENER_RECONNECT_BACKOFF_MAX_S``; the backoff resets on + every successful reconnection. Consumers poll while the connection is + unavailable and resume LISTEN/NOTIFY once it is restored. + """ + backoff = LISTENER_RECONNECT_BACKOFF_MIN_S + while not self._stop.is_set(): + if self._connection is None: + if not self._open_connection(): + self._stop.wait(backoff) + backoff = min(backoff * 2, LISTENER_RECONNECT_BACKOFF_MAX_S) + continue + backoff = LISTENER_RECONNECT_BACKOFF_MIN_S + try: + self._drain_pending_listen() + for notify in self._connection.notifies(timeout=0.5, stop_after=1): + self._wake_channel(notify.channel) + except Exception as e: + if self._stop.is_set(): + break + self._logger.warning( + "Shared LISTEN/NOTIFY listener error; consumers will fall back to polling while it " + "reconnects. Error: %s", + str(e), + exc_info=True, + ) + self._drop_connection() + + self._drop_connection() + + def close(self) -> None: + """Stop the dispatch thread and close the shared connection.""" + self._stop.set() + self._wake_all() + if self._thread is not None: + self._thread.join(timeout=1) + self._drop_connection() + + +class _PostgresConsumer(Consumer): + def __init__(self, broker: PostgresBroker, *, queue_name: str, prefetch: int, timeout: int) -> None: + """Initialize a consumer for a PGMQ queue. + + Parameters: + broker(PostgresBroker): Broker instance that owns the queue and database client. + queue_name(str): Name of the declared queue to consume from. + prefetch(int): Maximum number of messages fetched per read call; must be greater than or equal to 1. 
+ timeout(int): Idle wait timeout in milliseconds when polling for messages; must be greater than or equal to 0. + A value of 0 performs non-blocking reads. + + """ + if prefetch < 1: + raise ValueError("prefetch must be greater than or equal to 1") + if timeout < 0: + raise ValueError("timeout must be greater than or equal to 0") + + self.broker = broker + self.client = broker.client + self.queue_name = queue_name + self.prefetch = prefetch + self.timeout = timeout + self.visibility_timeout_seconds = broker.visibility_timeout_seconds + self.heartbeat_interval_seconds = broker.heartbeat_interval_seconds + self.messages: deque[PostgresQueueMessage] = deque() + self._notify_event = Event() + self._heartbeat_stop = Event() + self._heartbeat_thread: Thread | None = None + self._heartbeat_message_ids_lock = Lock() + self._heartbeat_message_ids: set[int] = set() + + if broker._listener is not None: + broker._listener.register(queue_name, self._notify_event) + self._start_heartbeat() + + @property + def wait_timeout_seconds(self) -> float: + """Return the listener wait timeout in seconds.""" + return self.timeout / 1000 if self.timeout > 0 else 0 + + @property + def _listener_available(self) -> bool: + """Whether the broker's shared LISTEN/NOTIFY listener is currently usable.""" + return self.broker._listener is not None and self.broker._listener.available + + def _normalize_messages( + self, messages: PostgresQueueMessage | list[PostgresQueueMessage] | None + ) -> list[PostgresQueueMessage]: + """Normalize a PGMQ read result into a list.""" + if messages is None: + return [] + return [messages] if not isinstance(messages, list) else messages + + def _read_immediate(self) -> list[PostgresQueueMessage]: + """Read up to ``prefetch`` messages without polling.""" + return self._normalize_messages( + self.client.read( + self.queue_name, + vt=self.visibility_timeout_seconds, + qty=self.prefetch, + ) + ) + + def _read_with_poll(self) -> list[PostgresQueueMessage]: + """Read up 
to ``prefetch`` messages using PGMQ polling. + + A zero timeout means non-blocking: do a single immediate read instead + of polling for the rounded-up one-second minimum. + """ + if self.timeout == 0: + return self._read_immediate() + return self._normalize_messages( + self.client.read_with_poll( + self.queue_name, + vt=self.visibility_timeout_seconds, + qty=self.prefetch, + max_poll_seconds=max((self.timeout + 999) // 1000, 1), + poll_interval_ms=max(min(self.timeout, 1000), 1), + ) + ) + + def _read_next_batch(self) -> list[PostgresQueueMessage]: + """Read the next batch, favoring LISTEN/NOTIFY when available.""" + if not self._listener_available: + return self._read_with_poll() + + messages = self._read_immediate() + if messages: + self._notify_event.clear() + return messages + + self._notify_event.wait(self.wait_timeout_seconds) + self._notify_event.clear() + return self._read_immediate() if self._listener_available else self._read_with_poll() + + def _start_heartbeat(self) -> None: + """Start the background visibility-timeout renewal thread.""" + self._heartbeat_thread = Thread( + target=self._run_heartbeat, + name=f"postgres-heartbeat-{self.queue_name}", + daemon=True, + ) + self._heartbeat_thread.start() + + def _run_heartbeat(self) -> None: + """Periodically renew the lease of in-flight messages. + + The loop sleeps for ``heartbeat_interval_seconds`` seconds, derived from the + millisecond broker input, unless shutdown is requested through + ``_heartbeat_stop``. On each tick it snapshots the current set of + tracked message ids under a lock, then calls ``set_vt`` to push their + visibility timeout back to ``self.visibility_timeout_seconds`` seconds. + + This prevents long-running handlers from becoming visible again and + being consumed twice before they are explicitly acked, nacked or + requeued. Failures are logged and retried on the next heartbeat tick. 
+ """ + while not self._heartbeat_stop.wait(self.heartbeat_interval_seconds): + with self._heartbeat_message_ids_lock: + message_ids = list(self._heartbeat_message_ids) + if not message_ids: + continue + try: + self.client.set_vt(self.queue_name, message_ids, self.visibility_timeout_seconds) + except Exception as e: + self.broker.logger.warning( + "Failed to extend visibility timeout heartbeat for queue %s. Error: %s", + self.queue_name, + str(e), + exc_info=True, + ) + + def _unregister_heartbeat_message_id(self, message_id: int) -> None: + """Stop tracking a message for heartbeat renewal.""" + with self._heartbeat_message_ids_lock: + self._heartbeat_message_ids.discard(message_id) + + def _requeue_message_ids(self, message_ids: list[int]) -> None: + """Make a batch of message ids visible again immediately.""" + for message_id in message_ids: + self._unregister_heartbeat_message_id(message_id) + if len(message_ids) > 0: + self.client.set_vt(self.queue_name, message_ids, 0) + + def _archive_message(self, message: "MessageProxy") -> None: + """Stop tracking a message and archive it, tolerating transient failures. + + A failed archive (connection blip, pool exhaustion, ...) is logged and + swallowed rather than propagated: letting it bubble up would kill the + worker thread, which has no restart logic. The message simply becomes + visible again once its visibility timeout expires and is redelivered, + which is the broker's at-least-once guarantee. 
+ """ + if not isinstance(message, _PostgresMessage): + raise ValueError("It must be a PostgresMessage") + self._unregister_heartbeat_message_id(message._postgres_message.msg_id) + try: + self.client.archive(self.queue_name, message._postgres_message.msg_id) + except Exception: + self.broker.logger.error( + "Failed to archive message %s on queue %s; it will be redelivered after its visibility timeout.", + message._postgres_message.msg_id, + self.queue_name, + exc_info=True, + ) + + @override + def ack(self, message: "MessageProxy") -> None: + """Archive a processed message.""" + self._archive_message(message) + + @override + def nack(self, message: "MessageProxy") -> None: + """Archive a failed message.""" + self._archive_message(message) + + @override + def requeue(self, messages: Iterable["MessageProxy"]) -> None: + """Make messages visible again immediately by resetting their visibility timeout.""" + message_ids = [ + message._postgres_message.msg_id for message in messages if isinstance(message, _PostgresMessage) + ] + self._requeue_message_ids(message_ids) + + def _build_message(self, postgres_message: PostgresQueueMessage) -> "_PostgresMessage": + """Wrap a raw PGMQ row as a Remoulade message proxy.""" + return _PostgresMessage(postgres_message) + + @override + def __next__(self) -> "_PostgresMessage | None": + """Return the next available message, or ``None`` if the queue stays empty.""" + if self.messages: + return self._build_message(self.messages.popleft()) + + messages = self._read_next_batch() + + if not messages: + return None + message_ids = [message.msg_id for message in messages] + with self._heartbeat_message_ids_lock: + self._heartbeat_message_ids.update(message_ids) + self.messages.extend(messages) + return self._build_message(self.messages.popleft()) + + @override + def close(self) -> None: + """Stop the heartbeat, unregister from the shared listener and requeue buffered messages.""" + self._heartbeat_stop.set() + self._notify_event.set() + if 
self.broker._listener is not None: + self.broker._listener.unregister(self.queue_name, self._notify_event) + if self._heartbeat_thread is not None: + self._heartbeat_thread.join(timeout=1) + self._requeue_message_ids([message.msg_id for message in self.messages]) + self.messages.clear() + + +class _PostgresMessage(MessageProxy): + def __init__(self, postgres_message: PostgresQueueMessage) -> None: + """Wrap a PGMQ message row as a Remoulade message proxy.""" + payload = postgres_message.message + if not isinstance(payload, dict): + raise UnsupportedMessageEncoding("PGMQ messages must contain JSON objects.") + try: + # Re-run the global message decoder so custom encoders (e.g. PydanticEncoder) + # can rehydrate actor args/kwargs to their typed schemas. + message = Message.decode_json(payload) + except TypeError as exc: + raise UnsupportedMessageEncoding("PGMQ message payload is not a valid Remoulade message envelope.") from exc + if message.options.get("eta", None) is not None: + raise UnsupportedMessageEncoding("eta option isn't supported with postgres broker") + super().__init__(message) + self._postgres_message = postgres_message diff --git a/remoulade/brokers/rabbitmq.py b/remoulade/brokers/rabbitmq.py index b60b1e299..844c4a180 100644 --- a/remoulade/brokers/rabbitmq.py +++ b/remoulade/brokers/rabbitmq.py @@ -21,7 +21,7 @@ from functools import partial from queue import Empty, Full, LifoQueue from threading import Lock, local -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Final, override from amqpstorm import AMQPChannelError, AMQPConnectionError, AMQPError, Channel, UriConnection from amqpstorm.compatibility import urlparse @@ -181,6 +181,7 @@ def clear_channel_pools(self): self.channel_pools["confirm_delivery"].clear() self.channel_pools["no_confirm_delivery"].clear() + @override def close(self) -> None: """Close all open RabbitMQ connections.""" @@ -193,6 +194,7 @@ def close(self) -> None: self.clear_channel_pools() 
self.logger.debug("Channels and connections closed.") + @override def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 5000) -> "_RabbitmqConsumer": """Create a new consumer for a queue. @@ -227,6 +229,7 @@ def _declare_rabbitmq_queues(self): self._declare_dq_queue(channel, queue_name) self._declare_xq_queue(channel, queue_name) + @override def declare_queue(self, queue_name: str) -> None: """Declare a queue. Has no effect if a queue with the given name already exists. @@ -283,6 +286,7 @@ def _declare_xq_queue(self, channel, queue_name): arguments["x-queue-type"] = "quorum" return channel.queue.declare(queue=xq_name(queue_name), durable=True, arguments=arguments) + @override def _apply_delay(self, message: "Message", delay: int | None = None) -> "Message": if delay is not None: message_eta = current_millis() + delay @@ -291,6 +295,7 @@ def _apply_delay(self, message: "Message", delay: int | None = None) -> "Message return message + @override @contextmanager def tx(self): with self.get_channel_pool(confirm_delivery=False).acquire() as channel, channel.tx: @@ -312,6 +317,7 @@ def _get_channel(self, confirm_delivery: bool): with self.get_channel_pool(confirm_delivery).acquire() as channel: yield channel + @override def _enqueue(self, message: "Message", *, delay: int | None = None) -> "Message": """Enqueue a message. 
@@ -344,7 +350,7 @@ def _enqueue(self, message: "Message", *, delay: int | None = None) -> "Message" self.logger.debug("Enqueueing message %r on queue %r.", message.message_id, queue_name) with self._get_channel(confirm_delivery) as channel: confirmation = channel.basic.publish( - exchange="", routing_key=queue_name, body=message.encode(), properties=properties + exchange="", routing_key=queue_name, body=message.encode_in_bytes(), properties=properties ) if confirm_delivery and not confirmation: raise MessageNotDelivered("Message could not be delivered") @@ -393,6 +399,7 @@ def get_queue_message_counts(self, queue_name: str): xq_queue_response["message_count"], ) + @override def flush(self, queue_name: str) -> None: """Drop all the messages from a queue. @@ -406,11 +413,13 @@ def flush(self, queue_name: str) -> None: with self.default_channel_pool.acquire() as channel: channel.queue.purge(name) + @override def flush_all(self) -> None: """Drop all messages from all declared queues.""" for queue_name in self.queues: self.flush(queue_name) + @override def join( self, queue_name: str, min_successes: int = 10, idle_time: int = 100, *, timeout: int | None = None ) -> None: @@ -458,6 +467,7 @@ def __init__(self, connection, queue_name, prefetch, timeout): except (AMQPConnectionError, AMQPChannelError) as e: raise ConnectionClosed(e) from None + @override def ack(self, message): try: message.ack() @@ -466,6 +476,7 @@ def ack(self, message): except Exception: # pragma: no cover self.logger.error("Failed to ack message.", exc_info=True) + @override def nack(self, message): try: message.nack(requeue=False) @@ -474,11 +485,13 @@ def nack(self, message): except Exception: # pragma: no cover self.logger.error("Failed to nack message.", exc_info=True) + @override def requeue(self, messages): """RabbitMQ automatically re-enqueues unacked messages when consumers disconnect so this is a no-op. 
""" + @override def __next__(self): """Return None if no value after timeout seconds""" try: @@ -494,6 +507,7 @@ def __next__(self): except (AMQPConnectionError, AMQPChannelError) as e: raise ConnectionClosed(e) from None + @override def close(self): with suppress(AMQPConnectionError, AMQPChannelError): self.channel.close() @@ -501,7 +515,7 @@ def close(self): class _RabbitmqMessage(MessageProxy): def __init__(self, rabbitmq_message): - super().__init__(Message.decode(rabbitmq_message.body)) + super().__init__(Message.decode_bytes(rabbitmq_message.body)) self._rabbitmq_message = rabbitmq_message diff --git a/remoulade/brokers/stub.py b/remoulade/brokers/stub.py index 968151db3..7dd137f4f 100644 --- a/remoulade/brokers/stub.py +++ b/remoulade/brokers/stub.py @@ -99,7 +99,7 @@ def _enqueue(self, message, *, delay=None): if queue_name not in self.queues: raise QueueNotFound(queue_name) - self.queues[queue_name].put(message.encode()) + self.queues[queue_name].put(message.encode_in_bytes()) return message def flush(self, queue_name): @@ -165,7 +165,7 @@ def nack(self, message): def __next__(self): try: data = self.queue.get(timeout=self.timeout / 1000) - message = Message.decode(data) + message = Message.decode_bytes(data) return MessageProxy(message) except Empty: return None diff --git a/remoulade/encoder.py b/remoulade/encoder.py index 3c6f52653..126bfede1 100644 --- a/remoulade/encoder.py +++ b/remoulade/encoder.py @@ -19,15 +19,15 @@ import json import pickle import warnings -from typing import Annotated, Any, get_type_hints +from typing import Annotated, Any, get_type_hints, override + +from remoulade.errors import UnsupportedMessageEncoding try: - from pydantic import BaseModel, TypeAdapter, WithJsonSchema - from simplejson.decoder import JSONDecoder - from simplejson.encoder import JSONEncoder as _JSONEncoder + from pydantic import TypeAdapter, WithJsonSchema except ImportError: # pragma: no cover warnings.warn( - "Pydantic and simplejson are not available. 
Run `pip install remoulade[pydantic]`", + "Pydantic is not available. Run `pip install remoulade[pydantic]`", ImportWarning, stacklevel=2, ) @@ -35,30 +35,59 @@ #: Represents the contents of a Message object as a dict. MessageData = dict[str, Any] +JsonData = dict[str, Any] class Encoder(abc.ABC): """Base class for message encoders.""" @abc.abstractmethod - def encode(self, data: MessageData) -> bytes: # pragma: no cover - """Convert message metadata into a bytestring.""" + def encode_in_bytes(self, data: MessageData) -> bytes: raise NotImplementedError @abc.abstractmethod - def decode(self, data: bytes) -> MessageData: # pragma: no cover - """Convert a bytestring into message metadata.""" + def decode_bytes(self, data: bytes) -> MessageData: + raise NotImplementedError + + def encode_in_json(self, data: MessageData) -> JsonData: + encoded = self._encode_in_json(data) + try: + json.dumps(encoded) + except (TypeError, ValueError) as e: + raise UnsupportedMessageEncoding("This is not a valid json") from e + return encoded + + @abc.abstractmethod + def _encode_in_json(self, data: MessageData) -> JsonData: + raise NotImplementedError + + @abc.abstractmethod + def decode_json(self, data: JsonData) -> MessageData: raise NotImplementedError class JSONEncoder(Encoder): """Encodes messages as JSON. This is the default encoder.""" - def encode(self, data: MessageData) -> bytes: + @override + def encode_in_bytes(self, data: MessageData) -> bytes: + """Convert message metadata into a bytestring.""" + # Serialize directly: routing through encode_in_json would add a + # throwaway json.dumps validation pass on top of this one. 
return json.dumps(data, separators=(",", ":")).encode("utf-8") - def decode(self, data: bytes) -> MessageData: - return json.loads(data.decode("utf-8")) + @override + def decode_bytes(self, data: bytes) -> MessageData: + """Convert a bytestring into message metadata.""" + return self.decode_json(json.loads(data.decode("utf-8"))) + + @override + def _encode_in_json(self, data: MessageData) -> JsonData: + return data + + @override + def decode_json(self, data: JsonData) -> MessageData: + return data class PickleEncoder(Encoder): @@ -69,8 +98,21 @@ class PickleEncoder(Encoder): Use it at your own risk. """ - encode = pickle.dumps # type: ignore - decode = pickle.loads # type: ignore + @override + def encode_in_bytes(self, data: MessageData) -> bytes: + return pickle.dumps(data) + + @override + def decode_bytes(self, data: bytes) -> MessageData: + return pickle.loads(data) # noqa: S301 + + @override + def _encode_in_json(self, data: MessageData) -> JsonData: + raise TypeError("PickleEncoder does not support JSON encoding.") + + @override + def decode_json(self, data: JsonData) -> MessageData: + raise TypeError("PickleEncoder does not support JSON decoding.") class PydanticEncoder(Encoder): @@ -91,31 +133,36 @@ def my_actor(input_1: MyActorInputSchema, input_2: MyActorInputSchema | None = N def __init__(self, fallback_encoder: Encoder | None = None): self.fallback_encoder = fallback_encoder - self.json_encoder = _JSONEncoder(default=self.default) - self.json_decoder = JSONDecoder() + self.json_adapter = TypeAdapter(object) - @staticmethod - def default(o): - if isinstance(o, BaseModel): - # keep dict otherwise it will be serialized as a string (see Pydantic .json()) - return json.loads(o.model_dump_json()) - raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable") + @override + def encode_in_bytes(self, data: MessageData) -> bytes: + try: + return json.dumps(self._encode_in_json(data)).encode("utf-8") + except Exception: + if 
self.fallback_encoder is not None: + return self.fallback_encoder.encode_in_bytes(data) + raise - def encode(self, data: MessageData) -> bytes: + @override + def decode_bytes(self, data: bytes) -> MessageData: try: - return self.json_encoder.encode(data).encode("utf-8") - except Exception as e: + return self.decode_json(json.loads(data.decode("utf-8"))) + except Exception: if self.fallback_encoder is not None: - return self.fallback_encoder.encode(data) - else: - raise e + return self.fallback_encoder.decode_bytes(data) + raise + + @override + def _encode_in_json(self, data: MessageData) -> JsonData: + return self.json_adapter.dump_python(data, mode="json") - def decode(self, data: bytes) -> MessageData: + @override + def decode_json(self, data: JsonData) -> MessageData: from remoulade import get_broker try: - raw_message = self.json_decoder.decode(data.decode("utf-8")) - actor_name = raw_message["actor_name"] + actor_name = data["actor_name"] actor_fn = get_broker().get_actor(actor_name).fn # Retrieve the Pydantic schemas from typing @@ -130,9 +177,9 @@ def decode(self, data: bytes) -> MessageData: ] ) - # Override message_data with Pydantic schema when it matches + # Override message_data with Pydantic schema when it matches. 
parsed_message: dict[str, Any] = {} - for key, values in raw_message.items(): + for key, values in data.items(): if key == "kwargs": if not isinstance(values, dict): raise TypeError(f"Expected `values` to be a dict, got {type(values).__name__}") @@ -156,8 +203,7 @@ def decode(self, data: bytes) -> MessageData: parsed_message[key] = values return parsed_message - except Exception as e: + except Exception: if self.fallback_encoder is not None: - return self.fallback_encoder.decode(data) - else: - raise e + return self.fallback_encoder.decode_json(data) + raise diff --git a/remoulade/errors.py b/remoulade/errors.py index 271d06d24..10e4dda46 100644 --- a/remoulade/errors.py +++ b/remoulade/errors.py @@ -64,6 +64,10 @@ class MessageNotDelivered(ConnectionError): """Raised when a message has not been delivered.""" +class UnsupportedMessageEncoding(BrokerError): + """Raised when the current message encoder does not produce JSON compatible payloads.""" + + class NoResultBackend(BrokerError): """Raised when trying to access the result backend on a broker without it""" diff --git a/remoulade/message.py b/remoulade/message.py index 516203cc0..5b3836c74 100644 --- a/remoulade/message.py +++ b/remoulade/message.py @@ -89,13 +89,22 @@ def asdict(self): return attr.asdict(self, recurse=False) @classmethod - def decode(cls, data): + def decode_bytes(cls, data: bytes): """Convert a bytestring to a message.""" - return cls(**global_encoder.decode(data)) + return cls(**global_encoder.decode_bytes(data)) - def encode(self): + def encode_in_bytes(self): """Convert this message to a bytestring.""" - return global_encoder.encode(self.asdict()) + return global_encoder.encode_in_bytes(self.asdict()) + + @classmethod + def decode_json(cls, data: dict[str, Any]): + """Convert a JSON object to a message.""" + return cls(**global_encoder.decode_json(data)) + + def encode_in_json(self): + """Convert this message to a JSON object.""" + return global_encoder.encode_in_json(self.asdict()) def 
copy(self, **attributes): """Create a copy of this message.""" diff --git a/remoulade/results/backends/redis.py b/remoulade/results/backends/redis.py index 9bc606ee4..f3895b73c 100644 --- a/remoulade/results/backends/redis.py +++ b/remoulade/results/backends/redis.py @@ -88,7 +88,7 @@ def get_results( for message_id in message_ids: message_key = self.build_message_key(message_id) if forget: - pipe.rpushx(message_key, self.encoder.encode(ForgottenResult.asdict())) + pipe.rpushx(message_key, self.encoder.encode_in_bytes(ForgottenResult.asdict())) pipe.lpop(message_key) else: pipe.rpoplpush(message_key, message_key) @@ -98,7 +98,7 @@ def get_results( continue # skip one row in two if forget as there is two commands if row is None: raise ResultMissing(message_id) - yield self.process_result(BackendResult(**self.encoder.decode(row)), raise_on_error) + yield self.process_result(BackendResult(**self.encoder.decode_bytes(row)), raise_on_error) def _brpoplpush_with_timeout(self, src, dst, timeout: int): """ @@ -150,7 +150,7 @@ async def async_get_result( if timeout > 0: if forget: async with self.async_client.pipeline() as pipe: - await pipe.rpushx(message_key, self.encoder.encode(ForgottenResult.asdict())) + await pipe.rpushx(message_key, self.encoder.encode_in_bytes(ForgottenResult.asdict())) await pipe.lpop(message_key) pipe_exec = await pipe.execute() data = pipe_exec[1] @@ -176,7 +176,7 @@ async def async_get_result( timeout = int(deadline - time.monotonic()) if timeout <= 0: # do not retry is timeout is expired raise - result = BackendResult(**self.encoder.decode(data)) + result = BackendResult(**self.encoder.decode_bytes(data)) return self.process_result(result, raise_on_error) def get_result(self, message_id: str, *, block=False, timeout=None, forget=False, raise_on_error=True): @@ -214,14 +214,14 @@ def get_result(self, message_id: str, *, block=False, timeout=None, forget=False data = self._brpoplpush_with_timeout(message_key, message_key, timeout=timeout) if 
forget and data is not None: with self.client.pipeline() as pipe: - pipe.lpushx(message_key, self.encoder.encode(ForgottenResult.asdict())) + pipe.lpushx(message_key, self.encoder.encode_in_bytes(ForgottenResult.asdict())) pipe.ltrim(message_key, 0, 0) pipe.execute() else: if forget: with self.client.pipeline() as pipe: - pipe.rpushx(message_key, self.encoder.encode(ForgottenResult.asdict())) + pipe.rpushx(message_key, self.encoder.encode_in_bytes(ForgottenResult.asdict())) pipe.lpop(message_key) data = pipe.execute()[1] else: @@ -251,21 +251,21 @@ def get_result(self, message_id: str, *, block=False, timeout=None, forget=False if block and timeout <= 0: # do not retry is timeout is expired raise - result = BackendResult(**self.encoder.decode(data)) + result = BackendResult(**self.encoder.decode_bytes(data)) return self.process_result(result, raise_on_error) def _store(self, message_keys, results, ttl): with self.client.pipeline() as pipe: for message_key, result in zip(message_keys, results, strict=False): pipe.delete(message_key) - pipe.lpush(message_key, self.encoder.encode(result)) + pipe.lpush(message_key, self.encoder.encode_in_bytes(result)) pipe.pexpire(message_key, ttl) pipe.execute() def _get(self, key, forget=False): data = self.client.rpop(key) if forget else self.client.rpoplpush(key, key) if data: - return self.encoder.decode(data) + return self.encoder.decode_bytes(data) return Missing def _delete(self, key): diff --git a/remoulade/results/backends/stub.py b/remoulade/results/backends/stub.py index 21ab3147a..3d613a47b 100644 --- a/remoulade/results/backends/stub.py +++ b/remoulade/results/backends/stub.py @@ -36,17 +36,17 @@ def _get(self, message_key: str, forget: bool = False): if forget: data, expiration = self.results.get(message_key, (None, None)) if data is not None: - self.results[message_key] = self.encoder.encode(ForgottenResult.asdict()), expiration + self.results[message_key] = self.encoder.encode_in_bytes(ForgottenResult.asdict()), 
expiration else: data, expiration = self.results.get(message_key, (None, None)) if data is not None and time.monotonic() < expiration: - return self.encoder.decode(data) + return self.encoder.decode_bytes(data) return Missing def _store(self, message_keys, results, ttl): for message_key, result in zip(message_keys, results, strict=False): - result_data = self.encoder.encode(result) + result_data = self.encoder.encode_in_bytes(result) expiration = time.monotonic() + int(ttl / 1000) self.results[message_key] = (result_data, expiration) diff --git a/remoulade/scheduler/scheduler.py b/remoulade/scheduler/scheduler.py index 6332621d2..45095500d 100644 --- a/remoulade/scheduler/scheduler.py +++ b/remoulade/scheduler/scheduler.py @@ -72,9 +72,9 @@ def get_hash(self) -> str: str(self.iso_weekday), str(self.enabled), self.tz, - *(encoder.encode(arg).decode() for arg in self.args), + *(encoder.encode_in_bytes(arg).decode() for arg in self.args), *( - f"{name}: {encoder.encode(arg).decode()}" + f"{name}: {encoder.encode_in_bytes(arg).decode()}" for name, arg in sorted(self.kwargs.items(), key=itemgetter(0)) ), ] @@ -104,11 +104,11 @@ def as_dict(self, encode: bool = False) -> dict: return job_dict def encode(self) -> bytes: - return get_encoder().encode(self.as_dict(encode=True)) + return get_encoder().encode_in_bytes(self.as_dict(encode=True)) @classmethod def decode(cls, data: bytes) -> "ScheduledJob": - data = get_encoder().decode(data) + data = get_encoder().decode_bytes(data) return ScheduledJob( actor_name=data["actor_name"], interval=data["interval"], diff --git a/remoulade/state/backend.py b/remoulade/state/backend.py index 783265917..b044dd6ec 100644 --- a/remoulade/state/backend.py +++ b/remoulade/state/backend.py @@ -99,7 +99,7 @@ def as_dict(self, exclude_keys=(), encode_args=False): for key in (item for item in ["args", "kwargs", "options"] if item in as_dict): try: - as_dict[key] = get_encoder().encode(as_dict[key]).decode("utf-8") + as_dict[key] = 
get_encoder().encode_in_bytes(as_dict[key]).decode("utf-8") except (UnicodeDecodeError, TypeError): as_dict[key] = "encoded_data" return as_dict @@ -201,16 +201,16 @@ def _encode_dict(self, data): """Return the (keys, values) of a dictionary encoded""" encoded_data = {} for key, value in data.items(): - encoded_value = self.encoder.encode(value) + encoded_value = self.encoder.encode_in_bytes(value) if sys.getsizeof(encoded_value) <= self.max_size: - encoded_data[self.encoder.encode(key)] = self.encoder.encode(value) + encoded_data[self.encoder.encode_in_bytes(key)] = self.encoder.encode_in_bytes(value) return encoded_data def _decode_dict(self, data): """Return the (keys, values) of a dictionary decoded""" decoded_data = {} for key, value in data.items(): - decoded_data[self.encoder.decode(key)] = self.encoder.decode(value) + decoded_data[self.encoder.decode_bytes(key)] = self.encoder.decode_bytes(value) return decoded_data def clean(self, max_age: int | None = None, not_started: bool = False): diff --git a/remoulade/state/backends/__init__.py b/remoulade/state/backends/__init__.py index 384ce0295..567912bf6 100644 --- a/remoulade/state/backends/__init__.py +++ b/remoulade/state/backends/__init__.py @@ -1,14 +1,3 @@ -try: - from .postgres import PostgresBackend -except ImportError: # pragma: no cover - import warnings - - warnings.warn( - "PostgresBackend is not available. 
Run `pip install remoulade[postgres]` to add support for that backend.", - ImportWarning, - stacklevel=2, - ) - try: from .redis import RedisBackend from .stub import StubBackend @@ -21,4 +10,4 @@ stacklevel=2, ) -__all__ = ["PostgresBackend", "RedisBackend", "StubBackend"] +__all__ = ["RedisBackend", "StubBackend"] diff --git a/remoulade/state/backends/postgres.py b/remoulade/state/backends/postgres.py deleted file mode 100644 index eb8a85278..000000000 --- a/remoulade/state/backends/postgres.py +++ /dev/null @@ -1,250 +0,0 @@ -import datetime -import os -import sys -import threading -from typing import TypeVar - -from sqlalchemy import ( - Column, - DateTime, - Float, - LargeBinary, - SmallInteger, - String, - create_engine, - distinct, - inspect, - or_, - text, -) -from sqlalchemy.orm import declarative_base -from sqlalchemy.orm.session import sessionmaker -from sqlalchemy.sql import func -from sqlalchemy.sql.functions import coalesce, count, max, min # noqa: A004 - -from remoulade import Encoder -from remoulade.state import State, StateBackend - -Base = declarative_base() - -DEFAULT_POSTGRES_URI = "postgresql://remoulade@localhost:5432/remoulade" -DB_VERSION = 3 -T = TypeVar("T", bound="StoredState") - - -class StoredState(Base): - __tablename__ = "states" - - message_id = Column(String(length=36), primary_key=True, index=True) - status = Column(String(length=10), index=True) - actor_name = Column(String(length=79), index=True) - args = Column(LargeBinary) - kwargs = Column(LargeBinary) - options = Column(LargeBinary) - priority = Column(SmallInteger) - progress = Column(Float) - enqueued_datetime = Column(DateTime(timezone=True), index=True) - started_datetime = Column(DateTime(timezone=True), index=True) - end_datetime = Column(DateTime(timezone=True), index=True) - queue_name = Column(String(length=60)) - composition_id = Column(String) - - def as_state(self, encoder: Encoder) -> State: - state_dict = {} - mapper = inspect(StoredState) - for column in 
mapper.attrs: - column_value = getattr(self, column.key) - if column_value is None: - continue - if column.key in ["args", "kwargs", "options"]: - column_value = encoder.decode(column_value) - state_dict[column.key] = column_value - return State.from_dict(state_dict) - - @classmethod - def from_state(cls: type[T], state: State, max_size: int, encoder: Encoder) -> T: - state_dict = state.as_dict() - for key in ["args", "kwargs", "options"]: - if key in state_dict: - encoded_value = encoder.encode(state_dict[key]) - state_dict[key] = encoded_value if sys.getsizeof(encoded_value) <= max_size else None - return cls(**state_dict) - - -class StateVersion(Base): - __tablename__ = "version" - - version = Column(SmallInteger, primary_key=True) - - -def filter_query( - *, - query, - selected_actors: list[str] | None, - selected_statuses: list[str] | None, - selected_message_ids: list[str] | None, - selected_composition_ids: list[str] | None, - start_datetime: datetime.datetime | None, - end_datetime: datetime.datetime | None, -): - if selected_actors is not None: - query = query.filter(StoredState.actor_name.in_(selected_actors)) - if selected_statuses is not None: - query = query.filter(StoredState.status.in_(selected_statuses)) - if selected_message_ids is not None: - query = query.filter(StoredState.message_id.in_(selected_message_ids)) - if selected_composition_ids is not None: - query = query.filter(StoredState.composition_id.in_(selected_composition_ids)) - if start_datetime is not None: - query = query.filter(StoredState.enqueued_datetime >= start_datetime) - if end_datetime is not None: - query = query.filter(StoredState.enqueued_datetime <= end_datetime) - - return query - - -class PostgresBackend(StateBackend): - def __init__( - self, - *, - namespace: str = "remoulade-state", - encoder: Encoder | None = None, - client: sessionmaker | None = None, - max_size: int = 2000000, - url: str | None = None, - future: bool = False, - ): - self.url = url or 
os.getenv("REMOULADE_POSTGRESQL_URL") or DEFAULT_POSTGRES_URI - super().__init__(namespace=namespace, encoder=encoder) - self.client = client or sessionmaker(create_engine(self.url, pool_pre_ping=True, future=future)) - self.init_db() - self.max_size = max_size - self.lock = threading.Lock() - - def init_db(self): - with self.client.begin() as session: - bind = session.get_bind() - insp = inspect(bind) - - if not insp.has_table("version"): - Base.metadata.create_all(bind=bind, tables=[StateVersion.__table__]) - - state_version = session.query(StateVersion).first() - if state_version is None: - session.add(StateVersion(version=DB_VERSION)) - - if not insp.has_table("states"): - Base.metadata.create_all(bind=bind, tables=[StoredState.__table__]) - elif state_version is None or state_version.version != DB_VERSION: - StoredState.__table__.drop(bind=bind) - Base.metadata.create_all(bind=bind, tables=[StoredState.__table__]) - if state_version is not None: - state_version.version = DB_VERSION - - def get_state(self, message_id: str): - with self.client.begin() as session: - state = session.query(StoredState).filter_by(message_id=message_id).first() - if state is None: - return None - return state.as_state(self.encoder) - - def set_state(self, state: State, ttl=3600): - with self.lock, self.client.begin() as session: - session.merge(StoredState.from_state(state, self.max_size, self.encoder)) - - def get_states( - self, - *, - size: int | None = None, - offset: int = 0, - selected_actors: list[str] | None = None, - selected_statuses: list[str] | None = None, - selected_message_ids: list[str] | None = None, - selected_composition_ids: list[str] | None = None, - start_datetime: datetime.datetime | None = None, - end_datetime: datetime.datetime | None = None, - sort_column: str | None = None, - sort_direction: str | None = None, - ): - sort_column = sort_column or "enqueued_datetime" - sort_direction = sort_direction or "desc" - - with self.client.begin() as session: - query 
= session.query(StoredState) - query = filter_query( - query=query, - selected_actors=selected_actors, - selected_statuses=selected_statuses, - selected_message_ids=selected_message_ids, - selected_composition_ids=selected_composition_ids, - start_datetime=start_datetime, - end_datetime=end_datetime, - ) - if size is not None: - query = query.subquery() - query_group = ( - session.query( - max(query.c.composition_id).label("grouped_composition_id"), - max(query.c.message_id).label("grouped_message_id"), - max(query.c.status).label("grouped_status"), - max(query.c.actor_name).label("grouped_actor_name"), - max(query.c.priority).label("grouped_priority"), - func.avg(query.c.progress).label("grouped_progress"), - min(query.c.enqueued_datetime).label("grouped_enqueued_datetime"), - min(query.c.started_datetime).label("grouped_started_datetime"), - max(query.c.end_datetime).label("grouped_end_datetime"), - max(query.c.queue_name).label("grouped_queue_name"), - ) - .group_by(coalesce(query.c.composition_id, query.c.message_id)) - .order_by(text(f"grouped_{sort_column} {sort_direction}")) - ) - query_group = query_group.offset(offset).limit(size).subquery() - query = ( - session.query(StoredState) - .select_from(StoredState) - .join( - query_group, - or_( - StoredState.message_id == query_group.c.grouped_message_id, - StoredState.composition_id == query_group.c.grouped_composition_id, - ), - ) - ) - query = query.order_by(text(f"{sort_column} {sort_direction}")) - return [state_model.as_state(self.encoder) for state_model in query] - - def get_states_count( - self, - *, - selected_actors: list[str] | None = None, - selected_statuses: list[str] | None = None, - selected_messages_ids: list[str] | None = None, - selected_composition_ids: list[str] | None = None, - start_datetime: datetime.datetime | None = None, - end_datetime: datetime.datetime | None = None, - **kwargs, - ): - with self.client.begin() as session: - query = 
session.query(count(distinct(coalesce(StoredState.composition_id, StoredState.message_id)))) - - query = filter_query( - query=query, - selected_actors=selected_actors, - selected_statuses=selected_statuses, - selected_message_ids=selected_messages_ids, - selected_composition_ids=selected_composition_ids, - start_datetime=start_datetime, - end_datetime=end_datetime, - ) - return query.first()[0] - - def clean(self, max_age: int | None = None, not_started: bool = False): - with self.client.begin() as session: - query = session.query(StoredState) - if max_age: - now = datetime.datetime.now(datetime.UTC) - min_datetime = now - datetime.timedelta(minutes=max_age) - query = session.query(StoredState).filter(StoredState.end_datetime < min_datetime) - if not_started: - query = session.query(StoredState).filter(StoredState.started_datetime.is_(None)) - query.delete() diff --git a/tests/conftest.py b/tests/conftest.py index 1cbd25c11..e07b44846 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,14 +10,15 @@ import redis from freezegun import freeze_time from sqlalchemy.engine import create_engine -from sqlalchemy.inspection import inspect from sqlalchemy.orm.session import sessionmaker from sqlalchemy.pool import NullPool +from sqlalchemy.sql import text import remoulade from remoulade import Worker from remoulade.api import app from remoulade.brokers.local import LocalBroker +from remoulade.brokers.postgres import PostgresBroker from remoulade.brokers.rabbitmq import RabbitmqBroker from remoulade.brokers.stub import StubBroker from remoulade.cancel import backends as cl_backends @@ -32,7 +33,6 @@ MessageState, backends as st_backends, ) -from remoulade.state.backends.postgres import DB_VERSION, StateVersion logfmt = "[%(asctime)s] [%(threadName)s] [%(name)s] [%(levelname)s] %(message)s" logging.basicConfig(level=logging.INFO, format=logfmt) @@ -64,20 +64,21 @@ def check_redis(client): def check_postgres(client): - with client.begin() as session: - insp = 
inspect(session.get_bind()) - version_exists = insp.has_table("version") - states_exists = insp.has_table("states") - if version_exists: - versions = session.query(StateVersion).all() - return version_exists and states_exists and len(versions) == 1 and versions[0].version == DB_VERSION + try: + with client.begin() as session: + session.execute(text("CREATE SCHEMA IF NOT EXISTS partman")) + session.execute(text("CREATE EXTENSION IF NOT EXISTS pg_partman WITH SCHEMA partman")) + session.execute(text("CREATE EXTENSION IF NOT EXISTS pgmq")) + session.execute(text("SELECT pgmq.validate_queue_name('remoulade_test_queue')")) + except Exception as e: + raise e from e if CI else pytest.skip("No connection to PostgreSQL/PGMQ server.") -@pytest.fixture -def check_postgres_begin(): - client = remoulade.get_broker().get_state_backend().client - if not check_postgres(client): - pytest.skip("Postgres Database is not in the proper state. Database initialisation is probably incorrect.") +def cleanup_postgres_queues(client): + with client.begin() as session: + queue_names = [row[0] for row in session.execute(text("SELECT queue_name FROM pgmq.list_queues()")).all()] + for queue_name in queue_names: + session.execute(text("SELECT pgmq.drop_queue(:queue_name)"), {"queue_name": queue_name}) @pytest.fixture() @@ -113,6 +114,27 @@ def rabbitmq_broker(request): broker.close() +@pytest.fixture() +def postgres_broker(): + db_string = os.getenv("REMOULADE_TEST_DB_URL") or "postgresql://remoulade@localhost:5544/test" + # Force the psycopg (v3) driver for the raw SQLAlchemy engine: SQLAlchemy defaults + # `postgresql://` to psycopg2, which the project does not depend on. This mirrors the driver + # swap PGMQ does internally for the broker. The plain `db_string` is kept for PostgresBroker, + # whose listener opens a raw psycopg.connect() connection that does not understand `+psycopg`. 
+ engine = create_engine(db_string.replace("postgresql://", "postgresql+psycopg://", 1), poolclass=NullPool) + client = sessionmaker(bind=engine) + check_postgres(client) + cleanup_postgres_queues(client) + + broker = PostgresBroker(url=db_string) + broker.emit_after("process_boot") + remoulade.set_broker(broker) + yield broker + cleanup_postgres_queues(client) + broker.emit_before("process_stop") + broker.close() + + @pytest.fixture() def local_broker(): broker = LocalBroker() @@ -200,19 +222,11 @@ def stub_state_backend(): @pytest.fixture -def postgres_state_backend(): - db_string = os.getenv("REMOULADE_TEST_DB_URL") or "postgresql://remoulade@localhost:5544/test" - backend = st_backends.PostgresBackend(client=sessionmaker(create_engine(db_string, poolclass=NullPool))) - backend.clean() - return backend +def state_backends(redis_state_backend, stub_state_backend): + return {"redis": redis_state_backend, "stub": stub_state_backend} -@pytest.fixture -def state_backends(postgres_state_backend, redis_state_backend, stub_state_backend): - return {"postgres": postgres_state_backend, "redis": redis_state_backend, "stub": stub_state_backend} - - -@pytest.fixture(params=["postgres", "redis", "stub"]) +@pytest.fixture(params=["redis", "stub"]) def state_backend(request, state_backends): return state_backends[request.param] @@ -225,14 +239,6 @@ def state_middleware(state_backend): return middleware -@pytest.fixture -def postgres_state_middleware(postgres_state_backend): - broker = remoulade.get_broker() - middleware = MessageState(backend=postgres_state_backend) - broker.add_middleware(middleware) - return middleware - - @pytest.fixture def result_middleware(result_backend): broker = remoulade.get_broker() diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index aaf02c9c8..05707d794 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -8,10 +8,12 @@ services: ports: - "5784:5672" postgres: - image: postgres + image: 
ghcr.io/pgmq/pg18-pgmq:v1.10.0 ports: - "5544:5432" environment: POSTGRES_USER: remoulade POSTGRES_HOST_AUTH_METHOD: trust - POSTGRES_DB: test \ No newline at end of file + POSTGRES_DB: test + volumes: + - ./docker/postgres/01-create-postgres-extension.sql:/docker-entrypoint-initdb.d/01-create-postgres-extension.sql:ro diff --git a/tests/docker/postgres/01-create-postgres-extension.sql b/tests/docker/postgres/01-create-postgres-extension.sql new file mode 100644 index 000000000..3afb0e3ff --- /dev/null +++ b/tests/docker/postgres/01-create-postgres-extension.sql @@ -0,0 +1,3 @@ +CREATE SCHEMA IF NOT EXISTS partman; +CREATE EXTENSION IF NOT EXISTS pg_partman WITH SCHEMA partman; +CREATE EXTENSION IF NOT EXISTS pgmq; diff --git a/tests/middleware/test_message_state.py b/tests/middleware/test_message_state.py index ee66000f9..fad062c3e 100644 --- a/tests/middleware/test_message_state.py +++ b/tests/middleware/test_message_state.py @@ -8,7 +8,6 @@ from remoulade.cancel import Cancel from remoulade.middleware import Middleware, SkipMessage from remoulade.state.backend import State, StateStatusesEnum -from remoulade.state.backends import PostgresBackend from remoulade.state.middleware import MessageState from tests.conftest import mock_func @@ -104,9 +103,6 @@ def before_process_message(self, broker, message): @pytest.mark.parametrize("ttl, result_type", [pytest.param(1000, State), pytest.param(1, type(None))]) def test_expiration_data_backend(self, ttl, result_type, stub_broker, state_backend): - if isinstance(state_backend, PostgresBackend): - pytest.skip("Skipping this test as there is no expiration on PostgresBackend") - @remoulade.actor def wait(): pass diff --git a/tests/state/test_backend.py b/tests/state/test_backend.py index 59c8b6972..5ea04c87e 100644 --- a/tests/state/test_backend.py +++ b/tests/state/test_backend.py @@ -1,7 +1,4 @@ -import pytest - from remoulade.state import State, StateStatusesEnum -from remoulade.state.backends import PostgresBackend class 
TestStateBackend: @@ -31,31 +28,3 @@ def test_count_messages(self, stub_broker, state_middleware): backend.set_state(State(f"id{i}")) assert backend.get_states_count() == 3 - - def test_count_compositions(self, stub_broker, state_middleware): - if not isinstance(state_middleware.backend, PostgresBackend): - pytest.skip() - - backend = state_middleware.backend - - for i in range(3): - for j in range(2): - backend.set_state(State(f"id{i * j}", composition_id=f"id{j}")) - - assert backend.get_states_count() == 2 - - def test_sort_with_offset(self, stub_broker, state_middleware): - if not isinstance(state_middleware.backend, PostgresBackend): - pytest.skip() - backend = state_middleware.backend - for i in range(8): - backend.set_state(State(f"id{i}", actor_name=f"{3 + 4 * (i // 4) - i % 4}")) - - res = backend.get_states(size=3, sort_column="actor_name", sort_direction="desc") - assert res[0].actor_name == "7" - assert res[1].actor_name == "6" - assert res[2].actor_name == "5" - res = backend.get_states(size=3, sort_column="actor_name", sort_direction="asc") - assert res[0].actor_name == "0" - assert res[1].actor_name == "1" - assert res[2].actor_name == "2" diff --git a/tests/state/test_postgres.py b/tests/state/test_postgres.py deleted file mode 100644 index 4f2acc8d2..000000000 --- a/tests/state/test_postgres.py +++ /dev/null @@ -1,65 +0,0 @@ -from remoulade.state.backends.postgres import DB_VERSION, StateVersion, StoredState -from tests.conftest import check_postgres - - -def test_no_changes(stub_broker, postgres_state_middleware, check_postgres_begin): - backend = postgres_state_middleware.backend - client = backend.client - with client.begin() as session: - session.add(StoredState(message_id="id")) - - backend.init_db() - assert check_postgres(client) - with client.begin() as session: - assert len(session.query(StoredState).all()) == 1 - - -def test_create_tables(stub_broker, postgres_state_middleware, check_postgres_begin): - backend = 
postgres_state_middleware.backend - client = backend.client - with client.begin() as session: - engine = session.get_bind() - StoredState.__table__.drop(bind=engine) - StateVersion.__table__.drop(bind=engine) - - backend.init_db() - assert check_postgres(client) - - -def test_change_version(stub_broker, postgres_state_middleware, check_postgres_begin): - backend = postgres_state_middleware.backend - client = backend.client - with client.begin() as session: - version = session.query(StateVersion).first() - version.version = DB_VERSION + 1 - session.add(StoredState(message_id="id")) - - backend.init_db() - assert check_postgres(client) - with client.begin() as session: - assert len(session.query(StoredState).all()) == 0 - - -def test_no_version(stub_broker, postgres_state_middleware, check_postgres_begin): - backend = postgres_state_middleware.backend - client = backend.client - with client.begin() as session: - engine = session.get_bind() - StateVersion.__table__.drop(bind=engine) - session.add(StoredState(message_id="id")) - - backend.init_db() - assert check_postgres(client) - with client.begin() as session: - assert len(session.query(StoredState).all()) == 0 - - -def test_no_states(stub_broker, postgres_state_middleware, check_postgres_begin): - backend = postgres_state_middleware.backend - client = backend.client - with client.begin() as session: - engine = session.get_bind() - StoredState.__table__.drop(bind=engine) - - backend.init_db() - assert check_postgres(client) diff --git a/tests/state/test_state_api.py b/tests/state/test_state_api.py index 0911a4bdf..39b319bc5 100644 --- a/tests/state/test_state_api.py +++ b/tests/state/test_state_api.py @@ -8,7 +8,6 @@ from zoneinfo import ZoneInfo import pytest -from dateutil.parser import parse import remoulade from remoulade import set_scheduler @@ -289,84 +288,6 @@ def do_work(): res = api_client.get(f"/messages/result/{message.message_id}") assert res.json["result"] == "non serializable result" - def 
test_select_actors(self, stub_broker, postgres_state_middleware): - backend = postgres_state_middleware.backend - for i in range(2): - backend.set_state(State(f"id{i}", actor_name=f"actor{i}")) - res = backend.get_states(selected_actors=["actor1"]) - assert len(res) == 1 - assert res[0].actor_name == "actor1" - - def test_select_statuses(self, stub_broker, postgres_state_middleware): - backend = postgres_state_middleware.backend - for i in range(2): - backend.set_state(State(f"id{i}", StateStatusesEnum.Success if i else StateStatusesEnum.Skipped)) - res = backend.get_states(selected_statuses=["Success"]) - assert len(res) == 1 - assert res[0].status == StateStatusesEnum.Success - - def test_select_ids(self, stub_broker, postgres_state_middleware): - backend = postgres_state_middleware.backend - for i in range(2): - backend.set_state(State(f"id{i}")) - res = backend.get_states(selected_message_ids=["id1"]) - assert len(res) == 1 - assert res[0].message_id == "id1" - - def test_select_datetimes(self, stub_broker, postgres_state_middleware): - backend = postgres_state_middleware.backend - for i in range(2): - backend.set_state(State(f"id{i}", enqueued_datetime=parse(f"2020-08-08 1{i}:00:00"))) - - res = backend.get_states(start_datetime=parse("2020-08-08 10:30:00")) - assert len(res) == 1 - assert res[0].message_id == "id1" - - res = backend.get_states(end_datetime=parse("2020-08-08 10:30:00")) - assert len(res) == 1 - assert res[0].message_id == "id0" - - def test_clean(self, stub_broker, postgres_state_middleware): - backend = postgres_state_middleware.backend - backend.set_state(State("id0")) - - assert len(backend.get_states()) == 1 - backend.clean() - assert len(backend.get_states()) == 0 - - def test_clean_max_age(self, stub_broker, postgres_state_middleware): - backend = postgres_state_middleware.backend - backend.set_state(State("id0", end_datetime=datetime.datetime.now(datetime.UTC))) - backend.set_state( - State("id1", 
end_datetime=datetime.datetime.now(datetime.UTC) - datetime.timedelta(minutes=50)) - ) - - assert len(backend.get_states()) == 2 - backend.clean(max_age=25) - res = backend.get_states() - assert len(res) == 1 - assert res[0].message_id == "id0" - - def test_clean_not_started(self, stub_broker, postgres_state_middleware): - backend = postgres_state_middleware.backend - backend.set_state(State("id0", started_datetime=datetime.datetime.now(datetime.UTC))) - backend.set_state(State("id1")) - - assert len(backend.get_states()) == 2 - backend.clean(not_started=True) - res = backend.get_states() - assert len(res) == 1 - assert res[0].message_id == "id0" - - def test_clean_route(self, stub_broker, postgres_state_middleware): - client = app.test_client() - backend = postgres_state_middleware.backend - backend.set_state(State("id0")) - - assert len(backend.get_states()) == 1 - client.delete("/messages/states") - assert len(backend.get_states()) == 0 - def test_cant_sort_by_args_kwargs_options(self, stub_broker, state_middleware, api_client): res = api_client.post( "/messages/states", data=json.dumps({"sort_column": "args"}), content_type="application/json" diff --git a/tests/test_actors.py b/tests/test_actors.py index 6d40acec2..57baef068 100644 --- a/tests/test_actors.py +++ b/tests/test_actors.py @@ -103,7 +103,7 @@ def add(x, y): # I expect it to enqueue a message enqueued_message = add.send(1, 2) enqueued_message_data = stub_broker.queues["default"].get(timeout=1) - assert enqueued_message == Message.decode(enqueued_message_data) + assert enqueued_message == Message.decode_bytes(enqueued_message_data) def test_actors_no_broker(): diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 76d0ba817..67035b4b9 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -45,6 +45,15 @@ def add_value(x): assert db == [1] +def test_json_encoder_json_round_trip(): + encoder = JSONEncoder() + data = {"queue_name": "default", "args": [1, 2], "kwargs": {"debug": 
True}} + + assert encoder.encode_in_json(data) == data + assert encoder.decode_json(data) == data + assert encoder.decode_bytes(encoder.encode_in_bytes(data)) == data + + class MyEnum(Enum): val = "val" other = "other" @@ -155,29 +164,46 @@ def encoder_with_fallback(stub_broker, stub_worker, result_backend) -> PydanticE def test_encoder_message(encoder: PydanticEncoder, message_data_decoded: MessageData, message_data_encoded: bytes): - encoded_result = encoder.encode(message_data_decoded) + encoded_result = encoder.encode_in_bytes(message_data_decoded) assert encoded_result == message_data_encoded - decoded_result = encoder.decode(message_data_encoded) + decoded_result = encoder.decode_bytes(message_data_encoded) # Args tuple are assumed to become list in remoulade assert decoded_result == tuple_to_list(message_data_decoded, "args") - assert encoder.decode(encoder.encode(message_data_decoded)) == tuple_to_list(message_data_decoded, "args") - assert encoder.encode(encoder.decode(message_data_encoded)) == message_data_encoded + assert encoder.decode_bytes(encoder.encode_in_bytes(message_data_decoded)) == tuple_to_list( + message_data_decoded, "args" + ) + assert encoder.encode_in_bytes(encoder.decode_bytes(message_data_encoded)) == message_data_encoded + + +def test_encoder_json_message_round_trip(pydantic_encoder, encoder: PydanticEncoder, message_data_decoded: MessageData): + message = Message(**message_data_decoded) + + encoded_json = message.encode_in_json() + + assert encoded_json["args"] == [json.loads(input_1.model_dump_json()) for input_1 in message.args] + assert encoded_json["kwargs"] == {"input_2": json.loads(message.kwargs["input_2"].model_dump_json())} + + decoded_message = Message.decode_json(encoded_json) + + assert decoded_message == message + assert isinstance(decoded_message.args[0], MyFirstSchema) + assert isinstance(decoded_message.kwargs["input_2"], MySecondSchema) def test_message_unknown_actor(encoder: PydanticEncoder, message_data_encoded: 
bytes): message_json_decoded = json.loads(message_data_encoded.decode("utf-8")) message_json_decoded["actor_name"] = "titi" with pytest.raises(ActorNotFound): - encoder.decode(json.dumps(message_json_decoded).encode("utf-8")) + encoder.decode_bytes(json.dumps(message_json_decoded).encode("utf-8")) def test_message_fallback_no_actor_name(encoder_with_fallback: PydanticEncoder, message_data_encoded: bytes): message_json_decoded = json.loads(message_data_encoded.decode("utf-8")) message_json_decoded["actor_name"] = "titi" - decoded_result = encoder_with_fallback.decode(json.dumps(message_json_decoded).encode("utf-8")) + decoded_result = encoder_with_fallback.decode_bytes(json.dumps(message_json_decoded).encode("utf-8")) # Do not raise and keep dict instead of schema assert decoded_result == tuple_to_list(message_json_decoded, "args") @@ -187,7 +213,7 @@ def test_message_schema_not_matching(encoder: PydanticEncoder, message_data_enco message_json_decoded["args"] = [{"toto": "a"}] message_json_decoded["kwargs"]["input_2"] = {"val": "aaa"} with pytest.raises(ValidationError): - encoder.decode(json.dumps(message_json_decoded).encode("utf-8")) + encoder.decode_bytes(json.dumps(message_json_decoded).encode("utf-8")) @pytest.fixture @@ -236,76 +262,76 @@ def backend_result_encoded_none() -> bytes: def test_encoder_result(encoder: PydanticEncoder, backend_result_decoded: MessageData, backend_result_encoded: bytes): - encoded_value = encoder.encode(backend_result_decoded) + encoded_value = encoder.encode_in_bytes(backend_result_decoded) assert encoded_value == backend_result_encoded - decoded_result = encoder.decode(backend_result_encoded) + decoded_result = encoder.decode_bytes(backend_result_encoded) assert decoded_result == backend_result_decoded - assert encoder.decode(encoder.encode(backend_result_decoded)) == backend_result_decoded - assert encoder.encode(encoder.decode(backend_result_encoded)) == backend_result_encoded + assert 
encoder.decode_bytes(encoder.encode_in_bytes(backend_result_decoded)) == backend_result_decoded + assert encoder.encode_in_bytes(encoder.decode_bytes(backend_result_encoded)) == backend_result_encoded def test_encoder_result_when_raise( encoder: PydanticEncoder, backend_result_decoded_raise: MessageData, backend_result_encoded_raise: bytes ): - encoded_value = encoder.encode(backend_result_decoded_raise) + encoded_value = encoder.encode_in_bytes(backend_result_decoded_raise) assert encoded_value == backend_result_encoded_raise - decoded_result = encoder.decode(backend_result_encoded_raise) + decoded_result = encoder.decode_bytes(backend_result_encoded_raise) assert decoded_result == backend_result_decoded_raise - assert encoder.decode(encoder.encode(backend_result_decoded_raise)) == backend_result_decoded_raise - assert encoder.encode(encoder.decode(backend_result_encoded_raise)) == backend_result_encoded_raise + assert encoder.decode_bytes(encoder.encode_in_bytes(backend_result_decoded_raise)) == backend_result_decoded_raise + assert encoder.encode_in_bytes(encoder.decode_bytes(backend_result_encoded_raise)) == backend_result_encoded_raise def test_encoder_result_tuple( encoder: PydanticEncoder, backend_result_decoded_tuple: MessageData, backend_result_encoded_tuple: bytes ): - encoded_value = encoder.encode(backend_result_decoded_tuple) + encoded_value = encoder.encode_in_bytes(backend_result_decoded_tuple) assert encoded_value == backend_result_encoded_tuple - decoded_result = encoder.decode(backend_result_encoded_tuple) + decoded_result = encoder.decode_bytes(backend_result_encoded_tuple) assert decoded_result == backend_result_decoded_tuple - assert encoder.decode(encoder.encode(backend_result_decoded_tuple)) == backend_result_decoded_tuple - assert encoder.encode(encoder.decode(backend_result_encoded_tuple)) == backend_result_encoded_tuple + assert encoder.decode_bytes(encoder.encode_in_bytes(backend_result_decoded_tuple)) == backend_result_decoded_tuple + 
assert encoder.encode_in_bytes(encoder.decode_bytes(backend_result_encoded_tuple)) == backend_result_encoded_tuple def test_encoder_result_with_none( encoder: PydanticEncoder, backend_result_decoded_none: MessageData, backend_result_encoded_none: bytes ): - encoded_value = encoder.encode(backend_result_decoded_none) + encoded_value = encoder.encode_in_bytes(backend_result_decoded_none) assert encoded_value == backend_result_encoded_none - decoded_result = encoder.decode(backend_result_encoded_none) + decoded_result = encoder.decode_bytes(backend_result_encoded_none) assert decoded_result == backend_result_decoded_none - assert encoder.decode(encoder.encode(backend_result_decoded_none)) == backend_result_decoded_none - assert encoder.encode(encoder.decode(backend_result_encoded_none)) == backend_result_encoded_none + assert encoder.decode_bytes(encoder.encode_in_bytes(backend_result_decoded_none)) == backend_result_decoded_none + assert encoder.encode_in_bytes(encoder.decode_bytes(backend_result_encoded_none)) == backend_result_encoded_none def test_backend_result_unknown_actor(encoder: PydanticEncoder, backend_result_encoded_tuple: bytes): backend_result_json_decoded = json.loads(backend_result_encoded_tuple.decode("utf-8")) backend_result_json_decoded["actor_name"] = "titi" with pytest.raises(ActorNotFound): - encoder.decode(json.dumps(backend_result_json_decoded).encode("utf-8")) + encoder.decode_bytes(json.dumps(backend_result_json_decoded).encode("utf-8")) def test_backend_result_schema_not_matching(encoder: PydanticEncoder, backend_result_encoded_tuple: bytes): backend_result_json_decoded = json.loads(backend_result_encoded_tuple.decode("utf-8")) backend_result_json_decoded["result"] = [{"val": "titi"}] with pytest.raises(ValidationError): - encoder.decode(json.dumps(backend_result_json_decoded).encode("utf-8")) + encoder.decode_bytes(json.dumps(backend_result_json_decoded).encode("utf-8")) def test_fallback_no_schema(encoder_with_fallback: PydanticEncoder, 
backend_result_encoded_tuple: bytes): backend_result_json_decoded = json.loads(backend_result_encoded_tuple.decode("utf-8")) backend_result_json_decoded["result"] = [{"val": "titi"}] - decoded_backend_result = encoder_with_fallback.decode(json.dumps(backend_result_json_decoded).encode("utf-8")) + decoded_backend_result = encoder_with_fallback.decode_bytes(json.dumps(backend_result_json_decoded).encode("utf-8")) # Do not raise and keep dict instead of schema assert decoded_backend_result == backend_result_json_decoded diff --git a/tests/test_postgres.py b/tests/test_postgres.py new file mode 100644 index 000000000..e6de40a36 --- /dev/null +++ b/tests/test_postgres.py @@ -0,0 +1,1095 @@ +import json +import logging +import os +import threading +import time +from unittest.mock import Mock + +import pytest +from pydantic import BaseModel +from sqlalchemy import JSON, Column, Integer, MetaData, Table, func +from sqlalchemy.sql import select, text + +import remoulade +from remoulade import Message, QueueJoinTimeout, UnsupportedMessageEncoding, Worker, group +from remoulade.brokers.postgres import PostgresBroker, _PostgresListener +from remoulade.encoder import Encoder, MessageData +from remoulade.results import Results +from remoulade.results.backends import StubBackend + +TEST_POSTGRES_URL = os.getenv("REMOULADE_TEST_DB_URL") or "postgresql://remoulade@localhost:5544/test" + + +def _count_messages(broker, queue_name="default"): + queue_table = Table(f"q_{queue_name}", MetaData(), Column("msg_id", Integer), schema="pgmq") + with broker.client.session() as session: + return session.execute(select(func.count()).select_from(queue_table)).scalar_one() + + +def _first_payload(broker, queue_name="default"): + queue_table = Table( + f"q_{queue_name}", + MetaData(), + Column("msg_id", Integer), + Column("message", JSON), + schema="pgmq", + ) + with broker.client.session() as session: + row = 
session.execute(select(queue_table.c.message).order_by(queue_table.c.msg_id).limit(1)).one() + return row[0] + + +def _count_archived_messages(broker, queue_name="default"): + archive_table = Table(f"a_{queue_name}", MetaData(), Column("msg_id", Integer), schema="pgmq") + with broker.client.session() as session: + return session.execute(select(func.count()).select_from(archive_table)).scalar_one() + + +def _queue_exists(broker, queue_name): + with broker.client.session() as session: + query = text("SELECT EXISTS(SELECT 1 FROM pgmq.list_queues() WHERE queue_name = :queue_name)") + return session.execute(query, {"queue_name": queue_name}).scalar_one() + + +def _expected_payload(message): + return json.loads(message.encode_in_bytes().decode("utf-8")) + + +class _StubTransaction: + """Stand-in for ``engine.begin()`` that yields a controllable connection. + + ``declare_queue`` always opens its own transaction and runs queue + operations against that connection, so tests inject a fake connection to + assert on the ``conn`` argument forwarded to the PGMQ client. + """ + + def __init__(self, connection): + self._connection = connection + + def __enter__(self): + return self._connection + + def __exit__(self, exc_type, exc, tb): + return None + + +class _FakeListener: + """Stand-in for the broker's shared LISTEN/NOTIFY listener. + + Lets tests control listener availability deterministically without opening + a real psycopg connection or starting the dispatch thread. 
+ """ + + def __init__(self, available): + self.available = available + + def register(self, queue_name, event): + pass + + def unregister(self, queue_name, event): + pass + + def close(self): + pass + + +def _install_listener(broker, *, available): + """Replace the broker's shared listener with a controllable fake.""" + listener = _FakeListener(available) + broker._listener = listener + return listener + + +def test_postgres_broker_uses_provided_url(): + broker_url = TEST_POSTGRES_URL + broker = PostgresBroker(url=broker_url) + + assert broker.url == broker_url + + +def test_postgres_broker_creates_partitioned_queue_with_default_intervals(): + broker = PostgresBroker( + url=TEST_POSTGRES_URL, + middleware=[], + ) + broker.client.validate_queue_name = Mock() + broker.client.create_partitioned_queue = Mock() + broker.client.create_queue = Mock() + broker.client.enable_notify = Mock() + broker._queue_exists = Mock(return_value=False) + + conn = Mock() + broker.client.engine.begin = Mock(return_value=_StubTransaction(conn)) + + broker.declare_queue("default") + + broker.client.validate_queue_name.assert_called_once_with("default", conn=conn) + broker.client.create_partitioned_queue.assert_called_once_with( + "default", + partition_interval="1 day", + retention_interval="7 days", + conn=conn, + ) + broker.client.enable_notify.assert_called_once_with("default", throttle_interval_ms=250, conn=conn) + broker.client.create_queue.assert_not_called() + + +def test_postgres_broker_uses_current_transaction_connection_for_queue_creation(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[]) + broker.client.validate_queue_name = Mock() + broker.client.create_partitioned_queue = Mock() + broker.client.enable_notify = Mock() + broker._queue_exists = Mock(return_value=False) + + transaction_connection = Mock() + + broker.client.engine.begin = Mock(return_value=_StubTransaction(transaction_connection)) + + with broker.tx(): + broker.declare_queue("default") + + 
broker.client.validate_queue_name.assert_called_once_with("default", conn=transaction_connection) + broker.client.create_partitioned_queue.assert_called_once_with( + "default", + partition_interval="1 day", + retention_interval="7 days", + conn=transaction_connection, + ) + broker.client.enable_notify.assert_called_once_with( + "default", + throttle_interval_ms=250, + conn=transaction_connection, + ) + + +def test_postgres_broker_enables_notify_on_postgresql_queue_init(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[]) + broker.client.validate_queue_name = Mock() + broker.client.create_partitioned_queue = Mock() + broker.client.enable_notify = Mock() + broker._queue_exists = Mock(return_value=False) + + conn = Mock() + broker.client.engine.begin = Mock(return_value=_StubTransaction(conn)) + + broker.declare_queue("default") + + broker.client.create_partitioned_queue.assert_called_once_with( + "default", + partition_interval="1 day", + retention_interval="7 days", + conn=conn, + ) + broker.client.enable_notify.assert_called_once_with("default", throttle_interval_ms=250, conn=conn) + + +def test_postgres_broker_does_not_fail_when_enable_notify_raises(caplog): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[]) + broker.client.validate_queue_name = Mock() + broker.client.create_partitioned_queue = Mock() + broker.client.enable_notify = Mock(side_effect=RuntimeError("notify unavailable")) + broker._queue_exists = Mock(return_value=False) + + broker.declare_queue("default") + + assert "default" in broker.queues + assert "Failed to enable LISTEN/NOTIFY" in caplog.text + + +def test_postgres_broker_declare_queue_is_idempotent_when_queue_already_exists(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[]) + broker.client.validate_queue_name = Mock() + broker.client.create_partitioned_queue = Mock() + broker.client.enable_notify = Mock() + broker._queue_exists = Mock(return_value=True) + + broker.declare_queue("default") + + 
broker.client.create_partitioned_queue.assert_not_called() + broker.client.enable_notify.assert_not_called() + assert "default" in broker.queues + + +def test_postgres_broker_poll_only_mode_opens_no_listener_and_skips_enable_notify(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[], enable_listen_notify=False) + broker.client.validate_queue_name = Mock() + broker.client.create_partitioned_queue = Mock() + broker.client.enable_notify = Mock() + broker._queue_exists = Mock(return_value=False) + + broker.declare_queue("default") + + assert broker._listener is None + broker.client.enable_notify.assert_not_called() + + +def test_postgres_broker_poll_only_consumer_never_reports_listener_available(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[], enable_listen_notify=False) + broker.queues["default"] = None + + message = Message(queue_name="default", actor_name="do_work", args=(9,), kwargs={}, options={}) + broker.client.read = Mock(return_value=None) + broker.client.read_with_poll = Mock(return_value=[Mock(msg_id=9, message=_expected_payload(message))]) + + consumer = broker.consume("default", prefetch=1, timeout=200) + consumed = next(consumer) + + assert consumed is not None + assert consumer._listener_available is False + broker.client.read.assert_not_called() + broker.client.read_with_poll.assert_called_once() + consumer.close() + + +def test_postgres_consumer_rejects_negative_timeout(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[], enable_listen_notify=False) + broker.queues["default"] = None + + with pytest.raises(ValueError, match="timeout must be greater than or equal to 0"): + broker.consume("default", prefetch=1, timeout=-1) + + broker.close() + + +def test_postgres_consumer_rejects_prefetch_below_one(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[], enable_listen_notify=False) + broker.queues["default"] = None + + with pytest.raises(ValueError, match="prefetch must be greater than or equal to 
1"): + broker.consume("default", prefetch=0, timeout=200) + + broker.close() + + +def test_postgres_poll_only_consumer_reads_immediately_when_timeout_is_zero(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[], enable_listen_notify=False) + broker.queues["default"] = None + broker.client.read = Mock(return_value=None) + broker.client.read_with_poll = Mock(return_value=None) + + consumer = broker.consume("default", prefetch=1, timeout=0) + + # A zero timeout is non-blocking: the polling fallback must do a single + # immediate read and return None at once instead of polling for ~1s. + assert next(consumer) is None + broker.client.read.assert_called_once() + broker.client.read_with_poll.assert_not_called() + consumer.close() + + +def test_postgres_broker_shares_a_single_listener_across_consumers(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[]) + broker.queues["first"] = None + broker.queues["second"] = None + listener = _install_listener(broker, available=True) + listener.register = Mock() + listener.unregister = Mock() + + first = broker.consume("first", prefetch=1, timeout=0) + second = broker.consume("second", prefetch=1, timeout=0) + + assert first.broker._listener is second.broker._listener + assert listener.register.call_count == 2 + registered_queues = {call.args[0] for call in listener.register.call_args_list} + assert registered_queues == {"first", "second"} + + first.close() + second.close() + assert listener.unregister.call_count == 2 + + +def test_postgres_broker_forwards_pool_size_to_client(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[], pool_size=3) + + assert broker.client.engine.pool.size() == 3 + + +def test_postgres_broker_uses_custom_partition_settings_when_provided(): + broker = PostgresBroker( + url=TEST_POSTGRES_URL, + middleware=[], + archive_partition_interval_in_days=2, + archive_retention_interval_in_days=14, + ) + broker.client.validate_queue_name = Mock() + 
broker.client.create_partitioned_queue = Mock() + broker.client.enable_notify = Mock() + broker._queue_exists = Mock(return_value=False) + + conn = Mock() + broker.client.engine.begin = Mock(return_value=_StubTransaction(conn)) + + broker.declare_queue("default") + + broker.client.create_partitioned_queue.assert_called_once_with( + "default", + partition_interval="2 days", + retention_interval="14 days", + conn=conn, + ) + + +def test_postgres_broker_rejects_non_json_message_encoders(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[]) + + class _MessageWithInvalidJson(Encoder): + def _encode_in_json(self, data): + raise TypeError("not json") + + def decode_json(self, data): + return data + + def encode_in_bytes(self, data: MessageData) -> bytes: + return b"" + + def decode_bytes(self, data: bytes) -> MessageData: + return {} + + with pytest.raises(UnsupportedMessageEncoding): + broker._encode_message(_MessageWithInvalidJson()) + + +def test_postgres_broker_rejects_nested_non_json_safe_payloads(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[]) + old_encoder = remoulade.get_encoder() + + class _NestedInvalidJsonEncoder(Encoder): + def _encode_in_json(self, data): + return {**data, "options": {"nested": {1}}} + + def decode_json(self, data): + return data + + def encode_in_bytes(self, data: MessageData) -> bytes: + return b"" + + def decode_bytes(self, data: bytes) -> MessageData: + return {} + + remoulade.set_encoder(_NestedInvalidJsonEncoder()) + try: + message = Message(queue_name="default", actor_name="do_work", args=(), kwargs={}, options={}) + + with pytest.raises(UnsupportedMessageEncoding): + broker._encode_message(message) + finally: + remoulade.set_encoder(old_encoder) + + +@pytest.mark.usefixtures("postgres_broker") +def test_postgres_broker_enqueue_stores_a_standard_remoulade_payload_as_jsonb(postgres_broker): + message = Message(queue_name="default", actor_name="do_work", args=(1, 2), kwargs={"debug": True}, options={}) + 
postgres_broker.declare_queue(message.queue_name) + + postgres_broker.enqueue(message) + + assert _count_messages(postgres_broker) == 1 + assert _first_payload(postgres_broker) == _expected_payload(message) + + +@pytest.mark.usefixtures("postgres_broker") +def test_postgres_broker_uses_native_visibility_delay_without_delay_queue(postgres_broker): + message = Message(queue_name="default", actor_name="do_work", args=(), kwargs={}, options={}) + postgres_broker.declare_queue(message.queue_name) + + postgres_broker.enqueue(message, delay=250) + + assert postgres_broker.client.read("default", vt=1) is None + + time.sleep(0.35) + delayed = postgres_broker.client.read("default", vt=1) + + assert delayed is not None + assert delayed.message == _expected_payload(message) + + +@pytest.mark.usefixtures("postgres_broker") +def test_postgres_broker_transactions_commit_and_rollback_messages(postgres_broker): + @remoulade.actor + def do_work(): + return 1 + + postgres_broker.declare_actor(do_work) + + with postgres_broker.tx(): + do_work.send() + + with pytest.raises(ValueError), postgres_broker.tx(): + do_work.send() + raise ValueError("rollback") + + assert _count_messages(postgres_broker) == 1 + + +def test_postgres_consumer_uses_notification_path_when_listener_is_available(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[]) + broker.queues["default"] = None + _install_listener(broker, available=True) + + message = Message(queue_name="default", actor_name="do_work", args=(1,), kwargs={}, options={}) + payload = _expected_payload(message) + broker.client.read = Mock(return_value=Mock(msg_id=1, message=payload)) + broker.client.read_with_poll = Mock(return_value=[]) + + consumer = broker.consume("default", prefetch=1, timeout=200) + consumer._notify_event.wait = Mock(return_value=False) + + consumed = next(consumer) + + assert consumed is not None + assert consumed.message_id == message.message_id + broker.client.read.assert_called_once_with("default", vt=30, 
qty=1) + broker.client.read_with_poll.assert_not_called() + consumer._notify_event.wait.assert_not_called() + assert consumer._listener_available is True + consumer.close() + + +def test_postgres_consumer_falls_back_to_polling_when_listener_is_unavailable(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[]) + broker.queues["default"] = None + _install_listener(broker, available=False) + + message = Message(queue_name="default", actor_name="do_work", args=(2,), kwargs={}, options={}) + payload = _expected_payload(message) + broker.client.read = Mock(return_value=None) + broker.client.read_with_poll = Mock(return_value=[Mock(msg_id=2, message=payload)]) + + consumer = broker.consume("default", prefetch=1, timeout=200) + consumed = next(consumer) + + assert consumed is not None + assert consumed.message_id == message.message_id + broker.client.read.assert_not_called() + broker.client.read_with_poll.assert_called_once_with( + "default", + vt=30, + qty=1, + max_poll_seconds=1, + poll_interval_ms=200, + ) + consumer.close() + + +def test_postgres_consumer_keeps_listener_path_after_an_empty_cycle(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[]) + broker.queues["default"] = None + _install_listener(broker, available=True) + + broker.client.read = Mock(return_value=None) + broker.client.read_with_poll = Mock(return_value=[]) + + consumer = broker.consume("default", prefetch=1, timeout=200) + consumer._notify_event.wait = Mock(return_value=False) + + consumed = next(consumer) + + assert consumed is None + assert consumer._listener_available is True + assert broker.client.read.call_count == 2 + broker.client.read_with_poll.assert_not_called() + consumer._notify_event.wait.assert_called_once_with(0.2) + consumer.close() + + +def test_postgres_consumer_uses_broker_visibility_timeout_for_reads(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[], visibility_timeout_ms=17_000) + broker.queues["default"] = None + 
_install_listener(broker, available=False) + + message = Message(queue_name="default", actor_name="do_work", args=(5,), kwargs={}, options={}) + payload = _expected_payload(message) + broker.client.read = Mock(return_value=None) + broker.client.read_with_poll = Mock(return_value=[Mock(msg_id=5, message=payload)]) + + consumer = broker.consume("default", prefetch=1, timeout=200) + consumed = next(consumer) + + assert consumed is not None + broker.client.read_with_poll.assert_called_once_with( + "default", + vt=17, + qty=1, + max_poll_seconds=1, + poll_interval_ms=200, + ) + consumer.close() + + +def test_postgres_consumer_tracks_all_prefetched_messages_for_heartbeat(monkeypatch): + def _fake_start_heartbeat(self): + self._heartbeat_thread = None + + monkeypatch.setattr("remoulade.brokers.postgres._PostgresConsumer._start_heartbeat", _fake_start_heartbeat) + + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[]) + broker.queues["default"] = None + _install_listener(broker, available=False) + + first_message = Message(queue_name="default", actor_name="do_work", args=(1,), kwargs={}, options={}) + second_message = Message(queue_name="default", actor_name="do_work", args=(2,), kwargs={}, options={}) + broker.client.read = Mock(return_value=None) + broker.client.read_with_poll = Mock( + return_value=[ + Mock(msg_id=1, message=_expected_payload(first_message)), + Mock(msg_id=2, message=_expected_payload(second_message)), + ] + ) + broker.client.set_vt = Mock() + + consumer = broker.consume("default", prefetch=2, timeout=200) + consumed = next(consumer) + + assert consumed is not None + assert consumed.message_id == first_message.message_id + with consumer._heartbeat_message_ids_lock: + assert consumer._heartbeat_message_ids == {1, 2} + + consumer.close() + + +def test_postgres_consumer_close_requeues_buffered_prefetched_messages(monkeypatch): + def _fake_start_heartbeat(self): + self._heartbeat_thread = None + + 
monkeypatch.setattr("remoulade.brokers.postgres._PostgresConsumer._start_heartbeat", _fake_start_heartbeat) + + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[]) + broker.queues["default"] = None + _install_listener(broker, available=False) + + first_message = Message(queue_name="default", actor_name="do_work", args=(1,), kwargs={}, options={}) + second_message = Message(queue_name="default", actor_name="do_work", args=(2,), kwargs={}, options={}) + broker.client.read = Mock(return_value=None) + broker.client.read_with_poll = Mock( + return_value=[ + Mock(msg_id=1, message=_expected_payload(first_message)), + Mock(msg_id=2, message=_expected_payload(second_message)), + ] + ) + broker.client.set_vt = Mock() + + consumer = broker.consume("default", prefetch=2, timeout=200) + consumed = next(consumer) + + assert consumed is not None + assert consumed.message_id == first_message.message_id + + consumer.close() + + broker.client.set_vt.assert_called_once_with("default", [2], 0) + + +def test_postgres_consumer_falls_back_to_polling_when_listener_stops_during_wait(): + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[]) + broker.queues["default"] = None + listener = _install_listener(broker, available=True) + + message = Message(queue_name="default", actor_name="do_work", args=(3,), kwargs={}, options={}) + payload = _expected_payload(message) + broker.client.read = Mock(return_value=None) + broker.client.read_with_poll = Mock(return_value=[Mock(msg_id=3, message=payload)]) + + consumer = broker.consume("default", prefetch=1, timeout=200) + + def _wait_and_stop(_timeout): + listener.available = False + return False + + consumer._notify_event.wait = Mock(side_effect=_wait_and_stop) + + consumed = next(consumer) + + assert consumed is not None + assert consumed.message_id == message.message_id + broker.client.read.assert_called_once_with("default", vt=30, qty=1) + broker.client.read_with_poll.assert_called_once_with( + "default", + vt=30, + qty=1, + 
max_poll_seconds=1, + poll_interval_ms=200, + ) + consumer.close() + + +def test_postgres_consumer_heartbeat_extends_inflight_message_visibility(): + broker = PostgresBroker( + url=TEST_POSTGRES_URL, + middleware=[], + visibility_timeout_ms=2_000, + heartbeat_interval_ms=50, + ) + broker.queues["default"] = None + _install_listener(broker, available=False) + + message = Message(queue_name="default", actor_name="do_work", args=(7,), kwargs={}, options={}) + payload = _expected_payload(message) + broker.client.read = Mock(return_value=None) + broker.client.read_with_poll = Mock(return_value=[Mock(msg_id=7, message=payload)]) + broker.client.set_vt = Mock() + broker.client.archive = Mock() + + consumer = broker.consume("default", prefetch=1, timeout=200) + consumed = next(consumer) + assert consumed is not None + + deadline = time.monotonic() + 1.0 + while broker.client.set_vt.call_count == 0 and time.monotonic() < deadline: + time.sleep(0.01) + + assert broker.client.set_vt.call_count >= 1 + queue_name, msg_ids, vt = broker.client.set_vt.call_args.args + assert queue_name == "default" + assert msg_ids == [7] + assert vt == 2 + + consumer.ack(consumed) + consumer.close() + + +def test_postgres_consumer_decodes_payload_with_global_encoder(pydantic_encoder): + class InputSchema(BaseModel): + value: int + + broker = PostgresBroker(url=TEST_POSTGRES_URL, middleware=[]) + remoulade.set_broker(broker) + broker.client.validate_queue_name = Mock() + broker.client.create_partitioned_queue = Mock() + broker.client.enable_notify = Mock() + broker._queue_exists = Mock(return_value=False) + _install_listener(broker, available=False) + + @remoulade.actor(actor_name="typed.actor", queue_name="default") + def typed_actor(payload: InputSchema): + return payload.value + + broker.declare_actor(typed_actor) + + message = Message( + queue_name="default", + actor_name="typed.actor", + args=(InputSchema(value=42),), + kwargs={}, + options={}, + ) + payload = _expected_payload(message) + 
broker.client.read = Mock(return_value=None) + broker.client.read_with_poll = Mock(return_value=[Mock(msg_id=1, message=payload)]) + + consumer = broker.consume("default", prefetch=1, timeout=200) + consumed = next(consumer) + + assert consumed is not None + assert isinstance(consumed.args[0], InputSchema) + assert consumed.args[0].value == 42 + consumer.close() + + +@pytest.mark.usefixtures("postgres_broker") +def test_postgres_consumer_reads_messages_and_acks_with_delete(postgres_broker): + message = Message(queue_name="default", actor_name="do_work", args=(42,), kwargs={}, options={}) + postgres_broker.declare_queue(message.queue_name) + postgres_broker.enqueue(message) + + consumer = postgres_broker.consume("default", prefetch=2, timeout=200) + consumed_message = next(consumer) + assert consumed_message is not None + assert consumed_message.message_id == message.message_id + + consumer.ack(consumed_message) + consumer.close() + + assert _count_messages(postgres_broker) == 0 + + +@pytest.mark.usefixtures("postgres_broker") +def test_postgres_consumer_nack_archives_messages(postgres_broker): + message = Message(queue_name="default", actor_name="do_work", args=(), kwargs={}, options={}) + postgres_broker.declare_queue(message.queue_name) + postgres_broker.enqueue(message) + + consumer = postgres_broker.consume("default", prefetch=1, timeout=200) + consumed_message = next(consumer) + + assert consumed_message is not None + consumer.nack(consumed_message) + consumer.close() + + assert _count_messages(postgres_broker) == 0 + assert _count_archived_messages(postgres_broker) == 1 + + +@pytest.mark.usefixtures("postgres_broker") +def test_postgres_consumer_requeue_restores_visibility_with_set_vt(postgres_broker): + message = Message(queue_name="default", actor_name="do_work", args=(), kwargs={}, options={}) + postgres_broker.declare_queue(message.queue_name) + postgres_broker.enqueue(message) + + consumer = postgres_broker.consume("default", prefetch=1, timeout=200) + 
consumed_message = next(consumer) + + assert consumed_message is not None + consumer.requeue([consumed_message]) + + replayed_message = next(consumer) + assert replayed_message is not None + assert replayed_message.message_id == message.message_id + + consumer.ack(replayed_message) + consumer.close() + + assert _count_messages(postgres_broker) == 0 + + +@pytest.mark.usefixtures("postgres_broker") +def test_postgres_worker_processes_native_delayed_messages_without_delay_queue(postgres_broker): + seen = [] + + @remoulade.actor + def do_work(value): + seen.append(value) + + postgres_broker.declare_actor(do_work) + worker = Worker(postgres_broker, worker_timeout=100, worker_threads=2) + worker.start() + try: + do_work.send_with_options(args=(3,), delay=150) + postgres_broker.join(do_work.queue_name, timeout=10_000) + assert seen == [3] + worker.join() + finally: + worker.stop() + + +@pytest.mark.usefixtures("postgres_broker") +def test_postgres_broker_join_times_out_while_processing_invisible_message(postgres_broker): + started = threading.Event() + release = threading.Event() + + @remoulade.actor + def do_work(): + started.set() + release.wait(timeout=5) + + postgres_broker.declare_actor(do_work) + + worker = Worker(postgres_broker, worker_timeout=100, worker_threads=1) + worker.start() + try: + do_work.send() + assert started.wait(timeout=2) + + with pytest.raises(QueueJoinTimeout): + postgres_broker.join(do_work.queue_name, timeout=100) + + release.set() + postgres_broker.join(do_work.queue_name, timeout=5_000) + worker.join() + finally: + release.set() + worker.stop() + + +@pytest.mark.usefixtures("postgres_broker") +def test_postgres_worker_processes_a_two_actor_pipeline(postgres_broker): + seen: list[tuple[str, int]] = [] + + @remoulade.actor + def first_actor(value): + seen.append(("first", value)) + return value + 1 + + @remoulade.actor + def second_actor(value): + seen.append(("second", value)) + + postgres_broker.declare_actor(first_actor) + 
postgres_broker.declare_actor(second_actor) + + worker = Worker(postgres_broker, worker_timeout=100, worker_threads=1) + worker.start() + try: + remoulade.pipeline([first_actor.message(1), second_actor.message()]).run() + + postgres_broker.join(second_actor.queue_name, timeout=10_000) + worker.join() + + assert seen == [("first", 1), ("second", 2)] + finally: + worker.stop() + + +@pytest.mark.usefixtures("postgres_broker") +def test_postgres_consumer_listener_wakes_on_enqueue_with_listen_notify(postgres_broker): + message = Message(queue_name="default", actor_name="do_work", args=(99,), kwargs={}, options={}) + postgres_broker.declare_queue(message.queue_name) + consumer = postgres_broker.consume(message.queue_name, prefetch=1, timeout=1500) + + if not consumer._listener_available: + pytest.skip("LISTEN/NOTIFY listener unavailable in this environment.") + + consumed_messages = [] + + def _consume_once(): + consumed_messages.append(next(consumer)) + + thread = threading.Thread(target=_consume_once) + thread.start() + try: + time.sleep(0.15) + postgres_broker.enqueue(message) + thread.join(timeout=3) + assert not thread.is_alive() + assert consumed_messages + consumed = consumed_messages[0] + assert consumed is not None + assert consumed.message_id == message.message_id + consumer.ack(consumed) + finally: + consumer.close() + + +class _FakeConnection: + """Minimal psycopg connection stand-in for listener reconnection tests.""" + + def __init__(self, *, raise_on_notify=False, on_notify=None): + self.raise_on_notify = raise_on_notify + self.on_notify = on_notify + self.listened = [] + self.closed = False + + def execute(self, statement): + self.listened.append(statement) + + def notifies(self, timeout=0.5, stop_after=1): + if self.raise_on_notify: + self.raise_on_notify = False + raise RuntimeError("connection lost") + if self.on_notify is not None: + self.on_notify() + return [] + + def close(self): + self.closed = True + + +def 
test_listener_open_connection_relistens_every_registered_channel(monkeypatch): + listener = _PostgresListener("postgresql://localhost/test", logging.getLogger("test")) + listener._channel_to_queue = { + listener._channel_for("first"): "first", + listener._channel_for("second"): "second", + } + connection = _FakeConnection() + monkeypatch.setattr("remoulade.brokers.postgres.psycopg.connect", Mock(return_value=connection)) + + assert listener._open_connection() is True + assert listener.available is True + assert listener._connection is connection + assert len(connection.listened) == 2 + + +def test_listener_open_connection_failure_keeps_listener_unavailable(monkeypatch, caplog): + listener = _PostgresListener("postgresql://localhost/test", logging.getLogger("test")) + monkeypatch.setattr("remoulade.brokers.postgres.psycopg.connect", Mock(side_effect=OSError("database is down"))) + + with caplog.at_level(logging.WARNING): + assert listener._open_connection() is False + + assert listener.available is False + assert listener._connection is None + assert "Failed to open shared LISTEN/NOTIFY connection" in caplog.text + + +def test_listener_reconnects_after_a_connection_drop(monkeypatch): + monkeypatch.setattr("remoulade.brokers.postgres.LISTENER_RECONNECT_BACKOFF_MIN_S", 0.01) + monkeypatch.setattr("remoulade.brokers.postgres.LISTENER_RECONNECT_BACKOFF_MAX_S", 0.01) + + recovered = threading.Event() + healthy_connection = _FakeConnection(on_notify=recovered.set) + connections = [_FakeConnection(raise_on_notify=True), healthy_connection] + connect = Mock(side_effect=connections) + monkeypatch.setattr("remoulade.brokers.postgres.psycopg.connect", connect) + + listener = _PostgresListener("postgresql://localhost/test", logging.getLogger("test")) + listener.register("default", threading.Event()) + try: + assert recovered.wait(timeout=2) + assert listener.available is True + assert connect.call_count >= 2 + assert connections[0].closed is True + assert 
len(healthy_connection.listened) == 1 + finally: + listener.close() + + +def test_listener_recovers_when_initial_connection_fails(monkeypatch): + monkeypatch.setattr("remoulade.brokers.postgres.LISTENER_RECONNECT_BACKOFF_MIN_S", 0.01) + monkeypatch.setattr("remoulade.brokers.postgres.LISTENER_RECONNECT_BACKOFF_MAX_S", 0.01) + + recovered = threading.Event() + healthy_connection = _FakeConnection(on_notify=recovered.set) + connect = Mock(side_effect=[OSError("database is down"), healthy_connection]) + monkeypatch.setattr("remoulade.brokers.postgres.psycopg.connect", connect) + + listener = _PostgresListener("postgresql://localhost/test", logging.getLogger("test")) + listener.register("default", threading.Event()) + try: + assert recovered.wait(timeout=2) + assert listener.available is True + assert connect.call_count >= 2 + finally: + listener.close() + + +def test_listener_close_stops_the_dispatch_thread(monkeypatch): + connection = _FakeConnection() + monkeypatch.setattr("remoulade.brokers.postgres.psycopg.connect", Mock(return_value=connection)) + + listener = _PostgresListener("postgresql://localhost/test", logging.getLogger("test")) + listener.register("default", threading.Event()) + + listener.close() + + assert listener.available is False + assert connection.closed is True + assert listener._thread is not None + assert listener._thread.is_alive() is False + + +# End-to-end tests running a real Worker against PGMQ, exercising the +# middleware-driven behaviours (retries, results, groups) that the unit tests +# above stub out. These validate the broker's ack/nack/requeue/heartbeat +# interplay through the full message lifecycle, not just in isolation. 
+ + +@pytest.mark.usefixtures("postgres_broker") +def test_postgres_worker_retries_failed_message_then_succeeds(postgres_broker): + attempts = [] + + @remoulade.actor(max_retries=3, min_backoff=50, max_backoff=50, jitter=False) + def flaky(): + attempts.append(1) + if len(attempts) < 2: + raise RuntimeError("boom") + + postgres_broker.declare_actor(flaky) + + worker = Worker(postgres_broker, worker_timeout=100, worker_threads=1) + worker.start() + try: + flaky.send() + postgres_broker.join(flaky.queue_name, timeout=10_000) + worker.join() + + # The actor ran twice: the failed attempt, then the retry that succeeded. + assert len(attempts) == 2 + # The failed attempt is archived (acked) and a delayed retry enqueued; + # the successful retry is archived too. Nothing is left in the queue. + assert _count_messages(postgres_broker) == 0 + assert _count_archived_messages(postgres_broker) == 2 + finally: + worker.stop() + + +@pytest.mark.usefixtures("postgres_broker") +def test_postgres_worker_archives_message_when_retries_exhausted(postgres_broker): + attempts = [] + + @remoulade.actor(max_retries=1, min_backoff=50, max_backoff=50, jitter=False) + def always_fails(): + attempts.append(1) + raise RuntimeError("boom") + + postgres_broker.declare_actor(always_fails) + + worker = Worker(postgres_broker, worker_timeout=100, worker_threads=1) + worker.start() + try: + always_fails.send() + postgres_broker.join(always_fails.queue_name, timeout=10_000) + worker.join() + + # max_retries=1 -> one initial attempt plus one retry, then it is failed. + assert len(attempts) == 2 + # Both the acked retry-source and the nacked exhausted message end up + # archived; an empty queue proves the nack does not leave the message + # invisible to be redelivered forever (PostgresBroker has no DLQ). 
+ assert _count_messages(postgres_broker) == 0 + assert _count_archived_messages(postgres_broker) == 2 + finally: + worker.stop() + + +@pytest.mark.usefixtures("postgres_broker") +def test_postgres_worker_stores_and_retrieves_actor_result(postgres_broker): + postgres_broker.add_middleware(Results(backend=StubBackend())) + + @remoulade.actor(store_results=True) + def do_work(): + return 42 + + postgres_broker.declare_actor(do_work) + + worker = Worker(postgres_broker, worker_timeout=100, worker_threads=1) + worker.start() + try: + message = do_work.send() + postgres_broker.join(do_work.queue_name, timeout=10_000) + worker.join() + + assert message.result.get(block=True) == 42 + finally: + worker.stop() + + +@pytest.mark.usefixtures("postgres_broker") +def test_postgres_worker_runs_group_with_results(postgres_broker): + postgres_broker.add_middleware(Results(backend=StubBackend())) + + @remoulade.actor(store_results=True) + def square(value): + return value * value + + postgres_broker.declare_actor(square) + + worker = Worker(postgres_broker, worker_timeout=100, worker_threads=4) + worker.start() + try: + g = group([square.message(value) for value in range(4)]) + g.run() + postgres_broker.join(square.queue_name, timeout=10_000) + worker.join() + + assert sorted(g.results.get(block=True)) == [0, 1, 4, 9] + finally: + worker.stop() + + +def test_listener_wakes_all_consumers_on_a_queue_and_unregister_spares_siblings(): + listener = _PostgresListener("postgresql://localhost/test", logging.getLogger("test")) + # Skip starting the dispatch thread: this test only exercises wake routing. + listener._started = True + + first, second = threading.Event(), threading.Event() + listener.register("default", first) + listener.register("default", second) + channel = listener._channel_for("default") + + # A notification on the shared queue wakes every registered consumer. 
+ listener._wake_channel(channel) + assert first.is_set() + assert second.is_set() + first.clear() + second.clear() + + # One consumer leaving must not stop notifications for its sibling. + listener.unregister("default", first) + listener._wake_channel(channel) + assert not first.is_set() + assert second.is_set() + + # The channel routing is dropped only once the last consumer leaves. + listener.unregister("default", second) + assert channel not in listener._channel_to_queue + assert "default" not in listener._events diff --git a/tests/test_results.py b/tests/test_results.py index f45cac153..0d82030f3 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -534,7 +534,7 @@ def test_redis_get_result_with_block_timeout_larger_than_socket_timeout(redis_re redis.TimeoutError(), redis.TimeoutError(), redis.TimeoutError(), - redis_result_backend.encoder.encode(ForgottenResult.asdict()), + redis_result_backend.encoder.encode_in_bytes(ForgottenResult.asdict()), ] redis_result_backend.max_retries = 1 redis_result_backend.get_result("message-id", block=True, forget=False, timeout=60 * 1000) @@ -545,7 +545,7 @@ def test_redis_get_result_with_block_timeout_larger_than_socket_timeout(redis_re @mock.patch("remoulade.results.backends.redis.compute_backoff", fast_backoff) def test_redis_get_result_still_return_result_if_forget_fails(redis_result_backend): with patch.object(redis_result_backend, "client") as mock_client: - mock_client.brpoplpush.return_value = redis_result_backend.encoder.encode(ForgottenResult.asdict()) + mock_client.brpoplpush.return_value = redis_result_backend.encoder.encode_in_bytes(ForgottenResult.asdict()) mock_client.pipeline.side_effect = redis.ConnectionError() assert redis_result_backend.get_result("message-id", block=True, forget=True) is None assert mock_client.brpoplpush.call_count == 1 @@ -613,7 +613,7 @@ async def test_redis_async_get_result_with_block_timeout_larger_than_socket_time redis.TimeoutError(), redis.TimeoutError(), 
redis.TimeoutError(), - redis_result_backend.encoder.encode(ForgottenResult.asdict()), + redis_result_backend.encoder.encode_in_bytes(ForgottenResult.asdict()), ] redis_result_backend.max_retries = 1 await redis_result_backend.async_get_result("message-id", forget=False, timeout=60 * 1000) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 2ebea8e75..e0d4abe45 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1,5 +1,4 @@ import datetime -import json import threading import time from zoneinfo import ZoneInfo @@ -260,143 +259,6 @@ def test_delete_job(scheduler): assert scheduler.get_redis_schedule() == {} -def test_get_scheduled_jobs(scheduler, api_client): - scheduler.schedule = [ScheduledJob(actor_name="do_work")] - scheduler.sync_config() - res = api_client.get("/scheduled/jobs") - assert res.status_code == 200 - assert res.json == { - "result": [ - { - "hash": scheduler.schedule[0].get_hash(), - "actor_name": "do_work", - "args": [], - "daily_time": None, - "enabled": True, - "interval": 86400, - "iso_weekday": None, - "kwargs": {}, - "last_queued": None, - "tz": "UTC", - } - ] - } - - -def test_api_add_job(scheduler, api_client, do_work): - res = api_client.post( - "/scheduled/jobs", data=json.dumps({"actor_name": "do_work"}), content_type="application/json" - ) - assert res.status_code == 200 - assert res.json == { - "result": [ - { - "hash": res.json["result"][0]["hash"], - "actor_name": "do_work", - "args": [], - "daily_time": None, - "enabled": True, - "interval": 86400, - "iso_weekday": None, - "kwargs": {}, - "last_queued": None, - "tz": "UTC", - } - ] - } - - -def test_api_delete_job(scheduler, api_client): - scheduler.schedule = [ScheduledJob(actor_name="do_work")] - scheduler.sync_config() - res = api_client.delete(f"/scheduled/jobs/{scheduler.schedule[0].get_hash()}") - assert res.status_code == 200 - assert res.json == {"result": []} - - -def test_update_job(scheduler, api_client, do_work): - scheduler.schedule = 
[ScheduledJob(actor_name="do_work")] - scheduler.sync_config() - res = api_client.put( - f"/scheduled/jobs/{scheduler.schedule[0].get_hash()}", - data=json.dumps({"actor_name": "do_work", "enabled": False}), - content_type="application/json", - ) - assert res.status_code == 200 - assert res.json == { - "result": [ - { - "hash": res.json["result"][0]["hash"], - "actor_name": "do_work", - "args": [], - "daily_time": None, - "enabled": False, - "interval": 86400, - "iso_weekday": None, - "kwargs": {}, - "last_queued": None, - "tz": "UTC", - } - ] - } - - -def test_api_update_jobs(scheduler, api_client, do_work): - scheduler.schedule = [ScheduledJob(actor_name="do_work"), ScheduledJob(actor_name="do_other_work")] - scheduler.sync_config() - res = api_client.put( - "/scheduled/jobs", - data=json.dumps( - { - "jobs": { - scheduler.schedule[0].get_hash(): {"actor_name": "do_work", "enabled": False}, - scheduler.schedule[1].get_hash(): {"actor_name": "do_work", "enabled": False, "interval": "55"}, - } - } - ), - content_type="application/json", - ) - assert res.status_code == 200 - assert len(res.json["result"]) == 2 - for i in range(2): - assert not res.json["result"][i]["enabled"] - - -def test_daily_time_wrong_interval(scheduler, api_client): - res = api_client.post( - "/scheduled/jobs", - data=json.dumps({"actor_name": "do_work", "daily_time": "00:00", "interval": 1000}), - content_type="application/json", - ) - assert res.status_code == 400 - - -def test_invalid_timezone(scheduler, api_client): - res = api_client.post( - "/scheduled/jobs", - data=json.dumps({"actor_name": "do_work", "tz": "invalid_tz"}), - content_type="application/json", - ) - assert res.status_code == 400 - - -def test_invalid_actor_name(scheduler, api_client): - res = api_client.post( - "/scheduled/jobs", data=json.dumps({"actor_name": "invalid_actor"}), content_type="application/json" - ) - assert res.status_code == 400 - - -def test_tz_aware_last_queued(scheduler, api_client, do_work): - res = 
api_client.post( - "scheduled/jobs", - data=json.dumps({"actor_name": "do_work", "last_queued": "2020-10-10 10:00:00Z"}), - content_type="application/json", - ) - - assert res.status_code == 400 - - class InputArg(BaseModel): data: str