diff --git a/examples/cases/stream/subscribe_to_rpc_stream.py b/examples/cases/stream/subscribe_to_rpc_stream.py new file mode 100644 index 0000000..300a68b --- /dev/null +++ b/examples/cases/stream/subscribe_to_rpc_stream.py @@ -0,0 +1,63 @@ +import asyncio +import logging +from asyncio import run +from signal import SIGINT, SIGTERM + +from examples.utils import BTC_USD_MARKET, create_stream_rpc_client, init_env +from x10.clients.streamrpc.subscription_params import ( + CandlesParams, + PricesParams, + TradesParams, +) +from x10.config import get_config_by_name +from x10.models.stream_rpc import StreamRpcResponseModel + +LOGGER = logging.getLogger() +MARKET_NAME = BTC_USD_MARKET + + +def on_message(message: StreamRpcResponseModel) -> None: + LOGGER.info("Received message: %s", message) + + +async def subscribe_to_rpc_stream(stop_event: asyncio.Event): + env_config = init_env() + client_config = get_config_by_name(env_config.client_config_name) + + async with create_stream_rpc_client(client_config) as client: + await client.ping() + + subscriptions_before = await client.list_subscriptions() + + LOGGER.info("Active subscriptions: %s", subscriptions_before) + + await client.subscribe(params=TradesParams(market="BTC-USD"), handler=on_message) + await client.subscribe(params=TradesParams(market="ETH-USD"), handler=on_message) + await client.subscribe(params=PricesParams(price_type="index", market="ETH-USD"), handler=on_message) + await client.subscribe( + params=CandlesParams(candle_type="index", market="ETH-USD", interval="PT1M"), handler=on_message + ) + + subscriptions_after = await client.list_subscriptions() + + LOGGER.info("Active subscriptions: %s", subscriptions_after) + + await stop_event.wait() + + +async def run_example(): + stop_event = asyncio.Event() + loop = asyncio.get_running_loop() + + def signal_handler(): + LOGGER.info("Signal received, stopping...") + stop_event.set() + + loop.add_signal_handler(SIGINT, signal_handler) + loop.add_signal_handler(SIGTERM, signal_handler) + + await subscribe_to_rpc_stream(stop_event) + + +if __name__ == "__main__": + run(main=run_example()) diff --git a/examples/utils.py b/examples/utils.py index 60cfabc..e4cf827 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -11,6 +11,7 @@ from x10.clients.blocking import BlockingTradingClient from x10.clients.rest import RestApiClient from x10.clients.stream import StreamClient +from x10.clients.streamrpc.streamrpc_client import StreamRpcClient from x10.config import get_config_by_name from x10.core.client_config import ClientConfig from x10.core.env_config import EnvConfig @@ -71,6 +72,10 @@ def create_stream_client(config: ClientConfig): return StreamClient(api_url=config.endpoints.stream_url) +def create_stream_rpc_client(config: ClientConfig): + return StreamRpcClient(api_url=config.endpoints.stream_rpc_url) + + def get_adjust_price_by_pct(config: TradingConfigModel): def adjust_price_by_pct(price: Decimal, pct: Decimal | int): return config.round_price(price + price * Decimal(pct) / 100) diff --git a/tests/clients/test_rest_api_client.py b/tests/clients/test_rest_api_client.py index 1b10d4e..376dc94 100644 --- a/tests/clients/test_rest_api_client.py +++ b/tests/clients/test_rest_api_client.py @@ -51,8 +51,8 @@ async def test_get_markets(aiohttp_server, create_btc_usd_market): "collateralAssetName": "USD", "collateralAssetPrecision": 6, "active": True, - "isRfq": True, - "isOffHours": True, + "isRfq": False, + "isOffHours": False, "marketStats": { "dailyVolume": "2410800.768021", "dailyVolumeBase": "37.94502", diff --git a/tests/clients/test_streamrpc_client.py b/tests/clients/test_streamrpc_client.py new file mode 100644 index 0000000..c53020a --- /dev/null +++ b/tests/clients/test_streamrpc_client.py @@ -0,0 +1,84 @@ +import asyncio +import json + +import pytest +import websockets +from hamcrest import assert_that, equal_to +from websockets import WebSocketServer + + +def get_url_from_server(server: WebSocketServer): + host, port = server.sockets[0].getsockname() # type: ignore[index] + return f"ws://{host}:{port}" + + +@pytest.mark.asyncio +async def test_candle_stream(): + from tests.fixtures.candle import create_candle_stream_rpc_message + from x10.clients.streamrpc.streamrpc_client import StreamRpcClient + from x10.clients.streamrpc.subscription_params import CandlesParams + + message_model = create_candle_stream_rpc_message() + received_messages: asyncio.Queue = asyncio.Queue() + + async def subscription_handler(msg): + await received_messages.put(msg) + + async def mock_server(websocket): + subscribe_msg_raw = await websocket.recv() + subscribe_msg = json.loads(subscribe_msg_raw) + + assert_that(subscribe_msg["method"], equal_to("subscribe")) + + await websocket.send( + json.dumps( + { + "id": subscribe_msg["id"], + "result": {"subscription": message_model.subscription}, + } + ) + ) + + await websocket.send(json.dumps(message_model.to_api_request_json())) + + unsubscribe_msg_raw = await websocket.recv() + unsubscribe_msg = json.loads(unsubscribe_msg_raw) + + assert_that(unsubscribe_msg["method"], equal_to("unsubscribe")) + + await websocket.send( + json.dumps( + { + "id": unsubscribe_msg["id"], + "result": {"method": "unsubscribe", "status": "OK"}, + } + ) + ) + + async with websockets.serve(mock_server, "127.0.0.1", 0) as server: + client = StreamRpcClient(api_url=get_url_from_server(server)) + await client.connect() + + subscription_params = CandlesParams(candle_type="last", market="BTC-USD", interval="PT1M") + subscription_id = await client.subscribe(params=subscription_params, handler=subscription_handler) + + msg = await asyncio.wait_for(received_messages.get(), timeout=5) + + await client.unsubscribe(subscription_id) + await client.close() + + assert_that( + msg.to_api_request_json(), + equal_to( + { + "type": "CANDLES", + "data": [ + {"o": "3458.64", "l": "3399.07", "h": "3476.89", "c": "3414.85", "v": "3.938", "T": 1721106000000} + ], + "error": None, + "ts": 1721283121979, + "seq": 1, + "subscription": "candles.last.BTC-USD.PT1M", + } + ), + ) diff --git a/tests/fixtures/candle.py b/tests/fixtures/candle.py index 39a538f..639f8b1 100644 --- a/tests/fixtures/candle.py +++ b/tests/fixtures/candle.py @@ -3,6 +3,7 @@ from x10.models.candle import CandleModel from x10.models.http import WrappedStreamResponseModel +from x10.models.stream_rpc import StreamRpcResponseModel def create_candle_stream_message(): @@ -20,3 +21,22 @@ def create_candle_stream_message(): ts=1721283121979, seq=1, ) + + +def create_candle_stream_rpc_message(): + return StreamRpcResponseModel( + type="CANDLES", + data=[ + CandleModel( + open=Decimal("3458.64"), + low=Decimal("3399.07"), + high=Decimal("3476.89"), + close=Decimal("3414.85"), + volume=Decimal("3.938"), + timestamp=1721106000000, + ) + ], + ts=1721283121979, + seq=1, + subscription="candles.last.BTC-USD.PT1M", + ) diff --git a/tests/fixtures/market.py b/tests/fixtures/market.py index 68dc5bb..8d0eaa9 100644 --- a/tests/fixtures/market.py +++ b/tests/fixtures/market.py @@ -15,8 +15,8 @@ def get_btc_usd_market_json_data(): "collateralAssetName": "USD", "collateralAssetPrecision": 6, "active": true, - "isRfq": true, - "isOffHours": true, + "isRfq": false, + "isOffHours": false, "marketStats": { "dailyVolume": "2410800.768021", "dailyVolumeBase": "37.94502", diff --git a/x10/clients/streamrpc/__init__.py b/x10/clients/streamrpc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/x10/clients/streamrpc/streamrpc_client.py b/x10/clients/streamrpc/streamrpc_client.py new file mode 100644 index 0000000..7abd2ca --- /dev/null +++ b/x10/clients/streamrpc/streamrpc_client.py @@ -0,0 +1,350 @@ +import asyncio +import json +import random +from typing import Any, Callable, Coroutine, TypeVar + +import websockets +from websockets import ConnectionClosed + +from x10.clients.streamrpc.streamrpc_dispatcher import ( + OnSequenceBreakCallback, + PendingRequestsMap, + RequestId, + StreamRpcDispatcher, +) +from x10.clients.streamrpc.subscription_params import ( + StreamMessageHandler, + SubscribeParams, + TopicId, + TopicSubscription, +) +from x10.errors import StreamRpcConnectionError, StreamRpcError, StreamRpcTimeoutError +from x10.utils.http import USER_AGENT, RequestHeader +from x10.utils.log import get_logger + +LOGGER = get_logger(__name__) +CONNECTION_LOOP_TASK_NAME = "x10-rpc-connection-loop" + +T = TypeVar("T") +OnReconnectCallback = Callable[[list[str]], Coroutine[Any, Any, None]] + + +class StreamRpcClient: + """ + EXPERIMENTAL! NOT TO BE USED IN PRODUCTION! TO BE IMPROVED IN THE UPCOMING VERSIONS. + + X10 WebSocket RPC client. + + Implements the JSON-RPC 2.0 like protocol over a WebSocket connection. + Supports automatic reconnection and transparent resubscription after connection loss. + + :param api_url: Full WebSocket URL. + :param on_reconnect: Optional async callback invoked after a successful reconnection. + :param on_sequence_break: Optional callback invoked when a gap is detected in the + connection-level ``seq`` counter, indicating that one or more stream + messages were dropped. + """ + + _ws: websockets.WebSocketClientProtocol | None + _ready: asyncio.Event + _connection_loop_task: asyncio.Task[None] | None + + async def connect(self): + """ + Starts the client's connection management loop and waits for the first connection to be established. + :raises StreamRpcConnectionError: If the initial connection is not established before the connect timeout. + """ + + if self._connection_loop_task is not None: + LOGGER.debug("Connection loop already running") + return + + LOGGER.debug("Connecting to %s", self._api_url) + + loop = asyncio.get_running_loop() + + self._is_stopped = False + self._connection_loop_task = loop.create_task(self._run_connection_loop(), name=CONNECTION_LOOP_TASK_NAME) + + try: + await asyncio.wait_for(self._ready.wait(), timeout=self._request_timeout) + except asyncio.TimeoutError as exc: + self._is_stopped = True + + if self._connection_loop_task: + self._connection_loop_task.cancel() + + raise StreamRpcConnectionError( + f"Connection to {self._api_url} timed out after {self._request_timeout}s" + ) from exc + + async def close(self): + """ + Stops the client and close the WebSocket connection. + """ + + self._is_stopped = True + self._ready.clear() + self._fail_pending(StreamRpcConnectionError("Client disconnected")) + + if self._ws is not None: + await self._ws.close() + self._ws = None + + if self._connection_loop_task is not None: + self._connection_loop_task.cancel() + + try: + await self._connection_loop_task + except (asyncio.CancelledError, Exception): + pass + + self._connection_loop_task = None + + async def ping(self): + """ + Sends a ping and wait for the server's acknowledgement. + """ + + await self._rpc("ping") + + async def list_subscriptions(self): + """ + Return the list of active subscription IDs as reported by the server. + """ + + result = await self._rpc("list-subscriptions") + return result["subscriptions"] + + async def subscribe(self, *, params: SubscribeParams[T], handler: StreamMessageHandler): + """ + Subscribe to a topic and register a handler for incoming messages. + + If a subscription with the same topic_id already exists it is replaced + (the server cancels the previous one automatically). + + :param params: Subscription parameters. + :param handler: Callable invoked for each message. May be sync or async. + :returns: The ``topic_id`` string. + """ + + await self._ready.wait() + + try: + self._subscriptions[params.topic_id] = TopicSubscription(params=params, handler=handler) + await self._rpc("subscribe", params=params.to_dict()) + except Exception: + LOGGER.error("Failed to subscribe to %s", params.topic_id) + self._subscriptions.pop(params.topic_id, None) + raise + + LOGGER.debug("Subscribed to %s", params.topic_id) + + return params.topic_id + + async def unsubscribe(self, topic_id: TopicId): + """ + Cancel an active subscription. + + :param topic_id: The string returned by :meth:`subscribe`. + :raises StreamRpcError: If no subscription with this ``topic_id`` exists. + """ + + subscription = self._subscriptions.get(topic_id) + + if subscription is None: + raise StreamRpcError(f"No active subscription: {topic_id}") + + await self._ready.wait() + + try: + self._subscriptions.pop(topic_id, None) + await self._rpc("unsubscribe", params=subscription.params.to_dict()) + except Exception: + LOGGER.error("Failed to unsubscribe from %s", topic_id) + raise + + LOGGER.debug("Unsubscribed from %s", topic_id) + + def __init__( + self, + *, + api_url: str, + on_reconnect: OnReconnectCallback | None = None, + on_sequence_break: OnSequenceBreakCallback | None = None, + ): + self._api_url = api_url + self._on_reconnect = on_reconnect + self._on_sequence_break = on_sequence_break + + self._request_timeout = 10 + self._reconnect_initial_delay = 1.0 + self._reconnect_max_delay = 10 + self._is_stopped = False + self._next_request_id = 0 + + self._ws = None + # Fires when a connection is established (and resubscription is done). + self._ready = asyncio.Event() + self._connection_loop_task = None + # Pending RPC requests (as futures) keyed by request id. + self._pending_requests: PendingRequestsMap = {} + # Active subscriptions keyed by topic id. + self._subscriptions: dict[TopicId, TopicSubscription] = {} + + self._dispatcher = StreamRpcDispatcher( + pending_requests=self._pending_requests, + subscriptions=self._subscriptions, + on_sequence_break=on_sequence_break, + ) + + async def __aenter__(self) -> "StreamRpcClient": + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + await self.close() + + def _get_next_request_id(self) -> RequestId: + self._next_request_id += 1 + return RequestId(self._next_request_id) + + async def _rpc(self, method: str, **kwargs: Any) -> dict[str, Any]: + """ + Send an RPC request and wait for its response. + """ + + if self._ws is None: + raise StreamRpcConnectionError("WebSocket connection is not open") + + request_id = self._get_next_request_id() + request: dict[str, Any] = {"method": method, "id": request_id, "jsonrpc": "2.0"} + + if kwargs: + request.update(kwargs) + + loop = asyncio.get_running_loop() + request_result: asyncio.Future[dict[str, Any]] = loop.create_future() + self._pending_requests[request_id] = request_result + + try: + await self._ws.send(json.dumps(request)) + # Shield the future so that cancelling the outer `wait_for` does not + # cancel the future itself (it is cleaned up in the `finally` block). + return await asyncio.wait_for(asyncio.shield(request_result), timeout=self._request_timeout) + except asyncio.TimeoutError as exc: + raise StreamRpcTimeoutError( + f"RPC request timed out: {method} (id={request_id}) after {self._request_timeout}s" + ) from exc + finally: + self._pending_requests.pop(request_id, None) + + def _fail_pending(self, exc: Exception) -> None: + """ + Resolve all pending RPC futures with an exception. + """ + + for request_result in list(self._pending_requests.values()): + if not request_result.done(): + request_result.set_exception(exc) + + self._pending_requests.clear() + + async def _run_connection_loop(self): + """ + Background task that maintains the connection (including reconnections) + and dispatches incoming messages. + """ + + reconnect_delay = self._reconnect_initial_delay + is_first_connection_attempt = True + + extra_headers: dict[str, str] = { + RequestHeader.USER_AGENT: USER_AGENT, + } + + async def handle_lost_connection(exc: Exception) -> bool: + nonlocal reconnect_delay + + self._ws = None + self._ready.clear() + self._fail_pending(StreamRpcConnectionError(f"Connection lost: {exc}")) + + LOGGER.warning("Connection lost: %s", exc) + + if self._is_stopped: + return False + + jitter = random.uniform(0.0, 1.0) + reconnect_after = min(reconnect_delay + jitter, self._reconnect_max_delay) + + LOGGER.debug("Reconnecting in %.1fs…", reconnect_after) + + await asyncio.sleep(reconnect_after) + reconnect_delay = min(reconnect_delay * 1.5, self._reconnect_max_delay) + + return True + + while not self._is_stopped: + try: + async with websockets.connect(self._api_url, extra_headers=extra_headers) as ws: + self._ws = ws + + LOGGER.debug("Connected to %s", self._api_url) + + # `seq` restarts at 0 on each new connection + self._dispatcher.reset_last_seq() + reconnect_delay = self._reconnect_initial_delay + + await self._resubscribe() + self._ready.set() + + if not is_first_connection_attempt and self._on_reconnect: + await self._on_reconnect(list(self._subscriptions)) + + is_first_connection_attempt = False + + async for raw in ws: + if isinstance(raw, str): + self._dispatcher.dispatch_raw(raw) + except asyncio.CancelledError: + break + except (ConnectionClosed, OSError, asyncio.TimeoutError) as exc: + should_try_to_reconnect = await handle_lost_connection(exc) + + if not should_try_to_reconnect: + break + except Exception as exc: + LOGGER.exception("Unexpected error in connection loop: %s", exc) + + self._ws = None + self._ready.clear() + self._fail_pending(StreamRpcConnectionError(str(exc))) + + if self._is_stopped: + break + + await asyncio.sleep(self._reconnect_initial_delay) + + self._ws = None + self._ready.clear() + self._fail_pending(StreamRpcConnectionError("Client stopped")) + + LOGGER.debug("Connection loop exited") + + async def _resubscribe(self) -> None: + """ + Replay all active subscriptions after a reconnection. + """ + + if self._ws is None or not self._subscriptions: + return + + LOGGER.debug("Resubscribing to topic(s): %s", ", ".join(list(self._subscriptions.keys()))) + + for topic_id, subscription in list(self._subscriptions.items()): + try: + await self._rpc("subscribe", params=subscription.params.to_dict()) + except Exception: + LOGGER.exception("Failed to resubscribe to %s", topic_id) + self._subscriptions.pop(topic_id, None) diff --git a/x10/clients/streamrpc/streamrpc_dispatcher.py b/x10/clients/streamrpc/streamrpc_dispatcher.py new file mode 100644 index 0000000..5881fd7 --- /dev/null +++ b/x10/clients/streamrpc/streamrpc_dispatcher.py @@ -0,0 +1,132 @@ +import asyncio +import json +from typing import Any, Callable, Coroutine, TypeAlias + +from x10.clients.streamrpc.subscription_params import TopicId, TopicSubscription +from x10.errors import StreamRpcServerError +from x10.models.stream_rpc import StreamRpcResponseModel +from x10.utils.log import get_logger + +LOGGER = get_logger(__name__) + +RequestId: TypeAlias = str +PendingRequestsMap: TypeAlias = dict[RequestId, asyncio.Future[dict[str, Any]]] +OnSequenceBreakCallback = Callable[[str, int, int], Coroutine[Any, Any, None]] + + +class StreamRpcDispatcher: + def __init__( + self, + *, + pending_requests: PendingRequestsMap, + subscriptions: dict[TopicId, TopicSubscription], + on_sequence_break: OnSequenceBreakCallback | None = None, + ) -> None: + # Last observed connection-level sequence (reset to `None` on each reconnect). + self._last_seq: int | None = None + self._pending_requests = pending_requests + self._subscriptions = subscriptions + self._on_sequence_break = on_sequence_break + + def reset_last_seq(self) -> None: + self._last_seq = None + + def dispatch_raw(self, raw: str) -> None: + """ + Parse a raw WebSocket text frame and route it to the right handler. + """ + + try: + msg: dict[str, Any] = json.loads(raw) + except json.JSONDecodeError: + LOGGER.warning("Received invalid JSON (%.120s…)", raw) + return + + # (1) JSON-RPC response + request_id: RequestId | None = msg.get("id") + + if request_id is not None: + request_result = self._pending_requests.get(str(request_id)) + + if request_result is None: + LOGGER.warning("Received response for unknown request id=%s", request_id) + return + + err = msg.get("error") + + if err: + request_result.set_exception( + StreamRpcServerError(code=err["code"], message=err["message"], data=err.get("data")) + ) + else: + request_result.set_result(msg["result"]) + + return + + # (2) Stream data + subscription_id: str | None = msg.get("subscription") + + if subscription_id is not None: + asyncio.ensure_future(self._dispatch_message(msg, subscription_id)) + return + + # (3) Unknown message + LOGGER.error("Unrecognised message shape: %s", raw) + + async def _dispatch_message(self, msg: dict[str, Any], subscription_id: str) -> None: + """ + Deserialize a stream message and invoke the registered handler. + """ + + subscription = self._subscriptions.get(subscription_id) + + if subscription is None: + LOGGER.warning("Received message for unknown subscription id=%s", subscription_id) + return + + try: + msg_model: StreamRpcResponseModel[Any] = StreamRpcResponseModel.model_validate(msg) + except Exception as exc: + LOGGER.exception("Failed to validate message for subscription %s: %s", subscription_id, exc) + return + + if self._last_seq is not None and msg_model.seq != self._last_seq + 1: + LOGGER.warning( + "Sequence break detected for subscription %s: last_seq=%s, msg_seq=%s", + subscription_id, + self._last_seq, + msg_model.seq, + ) + + if self._on_sequence_break: + try: + sequence_break_result = self._on_sequence_break(subscription_id, self._last_seq, msg_model.seq) + + if asyncio.iscoroutine(sequence_break_result): + await sequence_break_result + except Exception: + LOGGER.exception("Unhandled exception in `on_sequence_break` callback") + + self._last_seq = msg_model.seq + + try: + deserialized_data = subscription.params.deserialize_data(msg_model.data, msg_model.type) + except Exception as exc: + LOGGER.exception( + "Failed to deserialize message for subscription %s (type=%s, seq=%s): %s", + subscription_id, + msg_model.type, + msg_model.seq, + exc, + ) + return + + msg_model_with_deserialized_data = msg_model.model_copy(update={"data": deserialized_data}) + + try: + subscription_handler_result = subscription.handler(msg_model_with_deserialized_data) + + if asyncio.iscoroutine(subscription_handler_result): + await subscription_handler_result + except Exception: + LOGGER.exception("Unhandled exception in handler for subscription %s", subscription_id) diff --git a/x10/clients/streamrpc/subscription_params.py b/x10/clients/streamrpc/subscription_params.py new file mode 100644 index 0000000..1db682d --- /dev/null +++ b/x10/clients/streamrpc/subscription_params.py @@ -0,0 +1,252 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, Coroutine, Generic, TypeAlias, TypeVar + +from x10.errors import ValidationError +from x10.models.candle import CandleModel +from x10.models.funding_rate import FundingRateModel +from x10.models.orderbook import OrderbookUpdateModel +from x10.models.stream_rpc import ( + StreamRpcAccountBalanceModel, + StreamRpcAccountDepositUpdateModel, + StreamRpcAccountOrdersModel, + StreamRpcAccountPositionsModel, + StreamRpcAccountSpotBalancesModel, + StreamRpcAccountTradesModel, + StreamRpcAccountWithdrawalUpdateModel, + StreamRpcOrderbookUpdateModel, + StreamRpcPriceModel, + StreamRpcResponseModel, +) +from x10.models.trade import PublicTradeModel + +T = TypeVar("T") +TopicId: TypeAlias = str +StreamMessageHandler = Callable[[StreamRpcResponseModel[Any]], Coroutine[Any, Any, None] | None] + + +class SubscribeParams(ABC, Generic[T]): + """ + Base class for all subscription parameter types. + """ + + @property + @abstractmethod + def topic_id(self) -> TopicId: + """ + The unique topic identifier (e.g. ``trades.BTC-USD``). + """ + + @abstractmethod + def to_dict(self) -> dict[str, Any]: + """ + Serialize to the JSON structure expected by the RPC ``subscribe`` call. + """ + + @abstractmethod + def deserialize_data(self, data: Any, msg_type: str) -> T: + """ + Convert a raw JSON payload into the typed domain model ``T``. + + :param data: The raw dict from the ``data`` field of the envelope. + :param msg_type: The ``type`` field of the envelope, used by multi-type subscriptions + (e.g. ``account``) to select the correct model. + """ + + +class TradesParams(SubscribeParams[list[PublicTradeModel]]): + """ + Subscribe to public trade events for a market (or all markets). + """ + + def __init__(self, *, market: str | None = None) -> None: + self.market = market + + @property + def topic_id(self) -> TopicId: + return f"trades.{self.market or 'all'}" + + def to_dict(self) -> dict[str, Any]: + return {"scope": "trades", "selector": {"market": self.market}} + + def deserialize_data(self, data: list[dict[str, Any]], msg_type: str) -> list[PublicTradeModel]: + return [PublicTradeModel.model_validate(item) for item in data] + + +@dataclass +class TopicSubscription: + params: SubscribeParams[Any] + handler: StreamMessageHandler + + +class OrderbooksParams(SubscribeParams[OrderbookUpdateModel]): + """ + Subscribe to order book snapshots and delta updates. + + :param market: Market symbol or ``None`` for all markets. + :param depth: ``"full"`` (default) for the full order book, or ``"1"`` for best bid/ask only. + :param rfq_only: If ``True``, only include RFQ (request-for-quote) levels. Only valid when ``depth="full"``. + """ + + def __init__(self, *, market: str | None = None, depth: str = "full", rfq_only: bool = False) -> None: + if depth not in ("full", "1"): + raise ValidationError(f"`depth` must be `full` or `1`, got {depth!r}") + + if rfq_only and depth != "full": + raise ValidationError("`rfq_only` is only valid when depth is `full`") + + self.market = market + self.depth = depth + self.rfq_only = rfq_only + + @property + def topic_id(self) -> str: + if self.depth == "1": + return f"orderbooks.1.{self.market or 'all'}" + + return f"orderbooks.{self.market or 'all'}{'.rfq' if self.rfq_only else ''}" + + def to_dict(self) -> dict[str, Any]: + return { + "scope": "orderbooks", + "selector": { + "market": self.market, + "depth": self.depth, + "rfqOnly": self.rfq_only, + }, + } + + def deserialize_data(self, data: dict[str, Any], msg_type: str) -> StreamRpcOrderbookUpdateModel: + return StreamRpcOrderbookUpdateModel.model_validate(data) + + +class FundingRatesParams(SubscribeParams[FundingRateModel]): + """ + Subscribe to funding rate updates for a market (or all markets). + + :param market: Market symbol or ``None`` for all markets. + """ + + def __init__(self, *, market: str | None = None) -> None: + self.market = market + + @property + def topic_id(self) -> str: + return f"funding-rates.{self.market or 'all'}" + + def to_dict(self) -> dict[str, Any]: + return {"scope": "funding-rates", "selector": {"market": self.market}} + + def deserialize_data(self, data: dict[str, Any], msg_type: str) -> FundingRateModel: + return FundingRateModel.model_validate(data) + + +class PricesParams(SubscribeParams[StreamRpcPriceModel]): + """ + Subscribe to mark / index price updates for a market (or all markets) + + :param price_type: ``"mark"`` or ``"index"``. + :param market: Market symbol or ``None`` for all markets. + """ + + def __init__(self, *, price_type: str, market: str | None = None) -> None: + if price_type not in ("mark", "index"): + raise ValidationError(f"`price_type` must be `mark` or `index`, got {price_type!r}") + + self.price_type = price_type + self.market = market + + @property + def topic_id(self) -> str: + return f"prices.{self.price_type}.{self.market or 'all'}" + + def to_dict(self) -> dict[str, Any]: + return { + "scope": "prices", + "selector": {"type": self.price_type, "market": self.market}, + } + + def deserialize_data(self, data: dict[str, Any], msg_type: str) -> StreamRpcPriceModel: + return StreamRpcPriceModel.model_validate(data) + + +class CandlesParams(SubscribeParams[list[CandleModel]]): + """ + Subscribe to candles OHLC (`mark` or `index`) / OHLCV (`last`) for a market and interval. + + :param candle_type: ``"mark"``, ``"index"``, or ``"last"``. + :param market: Market symbol. + :param interval: ISO-8601 duration. + """ + + def __init__(self, *, candle_type: str, market: str, interval: str) -> None: + if candle_type not in ("mark", "index", "last"): + raise ValidationError(f"`candle_type` must be `mark`, `index`, or `last`, got {candle_type!r}") + + self.candle_type = candle_type + self.market = market + self.interval = interval + + @property + def topic_id(self) -> str: + return f"candles.{self.candle_type}.{self.market}.{self.interval}" + + def to_dict(self) -> dict[str, Any]: + return { + "scope": "candles", + "selector": {"type": self.candle_type, "market": self.market, "interval": self.interval}, + } + + def deserialize_data(self, data: list[dict[str, Any]], msg_type: str) -> list[CandleModel]: + return [CandleModel.model_validate(item) for item in data] + + +StreamRpcAccountUpdateType: TypeAlias = ( + StreamRpcAccountPositionsModel + | StreamRpcAccountOrdersModel + | StreamRpcAccountTradesModel + | StreamRpcAccountBalanceModel + | StreamRpcAccountSpotBalancesModel + | StreamRpcAccountDepositUpdateModel + | StreamRpcAccountWithdrawalUpdateModel +) + + +class _AccountParams(SubscribeParams[StreamRpcAccountUpdateType]): + """ + NOT SUPPORTED DUE TO AUTH ISSUES. TO BE FIXED IN THE UPCOMING VERSIONS. + + Subscribe to the private account stream. + """ + + def __init__(self, *, account: str) -> None: + self.account = account + + @property + def topic_id(self) -> str: + return f"account.{self.account}" + + def to_dict(self) -> dict[str, Any]: + return { + "scope": "account", + "selector": {"account": self.account}, + } + + def deserialize_data(self, data: dict[str, Any], msg_type: str) -> StreamRpcAccountUpdateType: + match msg_type: + case "ACCOUNT.POSITION": + return StreamRpcAccountPositionsModel.model_validate(data) + case "ACCOUNT.ORDER": + return StreamRpcAccountOrdersModel.model_validate(data) + case "ACCOUNT.TRADE": + return StreamRpcAccountTradesModel.model_validate(data) + case "ACCOUNT.BALANCE": + return StreamRpcAccountBalanceModel.model_validate(data) + case "ACCOUNT.SPOT_BALANCE": + return StreamRpcAccountSpotBalancesModel.model_validate(data) + case "ACCOUNT.DEPOSIT": + return StreamRpcAccountDepositUpdateModel.model_validate(data) + case "ACCOUNT.WITHDRAWAL": + return StreamRpcAccountWithdrawalUpdateModel.model_validate(data) + case _: + raise ValidationError(f"Unknown account stream message type: {msg_type!r}") diff --git a/x10/config.py b/x10/config.py index 9b1ced8..a39521a 100644 --- a/x10/config.py +++ b/x10/config.py @@ -22,6 +22,7 @@ api_base_url="https://api.starknet.sepolia.extended.exchange/api/v1", api_base_order_management_url="https://api.starknet.sepolia.extended.exchange/api/v1", stream_url="wss://api.starknet.sepolia.extended.exchange/stream.extended.exchange/v1", + stream_rpc_url="wss://api.starknet.sepolia.extended.exchange/stream.extended.exchange/v2/rpc", onboarding_url="https://api.starknet.sepolia.extended.exchange", vault_asset_name="XVS", ), @@ -38,6 +39,7 @@ api_base_url="https://api.starknet.extended.exchange/api/v1", api_base_order_management_url="https://api.starknet.extended.exchange/api/v1", stream_url="wss://api.starknet.extended.exchange/stream.extended.exchange/v1", + stream_rpc_url="wss://api.starknet.extended.exchange/stream.extended.exchange/v2/rpc", onboarding_url="https://api.starknet.extended.exchange", vault_asset_name="XVS", ), diff --git a/x10/core/client_config.py b/x10/core/client_config.py index b066788..316dfb0 100644 --- a/x10/core/client_config.py +++ b/x10/core/client_config.py @@ -30,6 +30,7 @@ class EndpointsConfig: api_base_url: str api_base_order_management_url: str stream_url: str + stream_rpc_url: str onboarding_url: str vault_asset_name: str diff --git a/x10/errors.py b/x10/errors.py index 194c911..5823ff2 100644 --- a/x10/errors.py +++ b/x10/errors.py @@ -10,6 +10,44 @@ class NotSupportedError(SdkError, NotImplementedError): pass +class StreamRpcError(SdkError): + pass + + +class StreamRpcServerError(StreamRpcError): + PARSE_ERROR = -32700 + INVALID_REQUEST = -32600 + METHOD_NOT_FOUND = -32601 + INVALID_PARAMS = -32602 + INTERNAL_ERROR = -32603 + UNAUTHORIZED = -32001 + + def __init__(self, code: int, message: str, data: object = None) -> None: + super().__init__(f"[{code}] {message}") + + self.code = code + self.message = message + self.data = data + + +class StreamRpcConnectionError(StreamRpcError): + """ + WebSocket connection is unavailable. + """ + + +class StreamRpcTimeoutError(StreamRpcError): + """ + RPC request times out waiting for a response. + """ + + +class StreamRpcParseError(StreamRpcError): + """ + Incoming message cannot be parsed as JSON. + """ + + class ApiError(SdkError): pass diff --git a/x10/models/deposit.py b/x10/models/deposit.py new file mode 100644 index 0000000..2356439 --- /dev/null +++ b/x10/models/deposit.py @@ -0,0 +1,18 @@ +from decimal import Decimal + +from strenum import StrEnum + +from x10.models.base import X10BaseModel + + +class DepositStatus(StrEnum): + CREATED = "CREATED" + PROCESSED = "PROCESSED" + REJECTED = "REJECTED" + + +class DepositStatusUpdateModel(X10BaseModel): + asset_id: int + amount: Decimal + timestamp: int + status: DepositStatus diff --git a/x10/models/stream_rpc.py b/x10/models/stream_rpc.py new file mode 100644 index 0000000..527c2c1 --- /dev/null +++ b/x10/models/stream_rpc.py @@ -0,0 +1,69 @@ +from decimal import Decimal +from typing import Generic, TypeVar + +from pydantic import AliasChoices, Field + +from x10.models.balance import BalanceModel, SpotBalanceModel +from x10.models.base import X10BaseModel +from x10.models.deposit import DepositStatusUpdateModel +from x10.models.order import OpenOrderModel +from x10.models.orderbook import OrderbookUpdateModel +from x10.models.position import PositionModel +from x10.models.trade import AccountTradeModel +from x10.models.withdrawal import WithdrawalStatusUpdateModel + +T = TypeVar("T") + + +class StreamRpcResponseModel(X10BaseModel, Generic[T]): + type: str + data: T + error: str | None = None + ts: int + seq: int + subscription: str + + +class StreamRpcOrderbookUpdateModel(OrderbookUpdateModel): + depth: str = Field(validation_alias=AliasChoices("depth", "d"), serialization_alias="d") + + +class StreamRpcPriceModel(X10BaseModel): + market: str = Field(validation_alias=AliasChoices("market", "m"), serialization_alias="m") + price: Decimal = Field(validation_alias=AliasChoices("price", "p"), serialization_alias="p") + ts: int + + +class StreamRpcAccountPositionsModel(X10BaseModel): + is_snapshot: bool + positions: list[PositionModel] + + +class StreamRpcAccountOrdersModel(X10BaseModel): + is_snapshot: bool + orders: list[OpenOrderModel] + + +class StreamRpcAccountTradesModel(X10BaseModel): + is_snapshot: bool + trades: list[AccountTradeModel] + + +class StreamRpcAccountBalanceModel(X10BaseModel): + is_snapshot: bool + balance: BalanceModel + + +class StreamRpcAccountSpotBalancesModel(X10BaseModel): + is_snapshot: bool + spot_balances: list[SpotBalanceModel] + + +class StreamRpcAccountDepositUpdateModel(X10BaseModel): + is_snapshot: bool + deposit: DepositStatusUpdateModel + + +class StreamRpcAccountWithdrawalUpdateModel(X10BaseModel): + is_snapshot: bool + withdrawal: WithdrawalStatusUpdateModel diff --git a/x10/models/withdrawal.py b/x10/models/withdrawal.py index 5886477..33d8b4e 100644 --- a/x10/models/withdrawal.py +++ b/x10/models/withdrawal.py @@ -1,8 +1,18 @@ from decimal import Decimal +from strenum import StrEnum + from x10.models.base import HexValue, SettlementSignatureModel, X10BaseModel +class WithdrawalStatus(StrEnum): + CREATED = "CREATED" + REJECTED = "REJECTED" + IN_PROGRESS = "IN_PROGRESS" + READY_FOR_CLAIM = "READY_FOR_CLAIM" + COMPLETED = "COMPLETED" + + class TimestampModel(X10BaseModel): seconds: int @@ -25,3 +35,11 @@ class WithdrawalRequestModel(X10BaseModel): chain_id: str quote_id: str | None = None asset: str + + +class WithdrawalStatusUpdateModel(X10BaseModel): + id: int + asset_id: int + amount: Decimal + status: WithdrawalStatus + reason: str | None = None