diff --git a/README.md b/README.md index f3c4ba81..480e8de2 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,11 @@ from python_opensky import OpenSky, StatesResponse async def main() -> None: """Show example of fetching all flight states.""" async with OpenSky() as opensky: + # Optional: authenticate for higher rate limits + # await opensky.authenticate( + # client_id="your_client_id", + # client_secret="your_client_secret", + # ) states: StatesResponse = await opensky.get_states() print(states) diff --git a/src/python_opensky/const.py b/src/python_opensky/const.py index e4529f8b..2b0bad1e 100644 --- a/src/python_opensky/const.py +++ b/src/python_opensky/const.py @@ -46,3 +46,6 @@ class AircraftCategory(int, Enum): MAX_LATITUDE = "lamax" MIN_LONGITUDE = "lomin" MAX_LONGITUDE = "lomax" + +TOKEN_URL = "https://auth.opensky-network.org/auth/realms/opensky-network/protocol/openid-connect/token" # noqa: S105 +TOKEN_REFRESH_MARGIN = 30 diff --git a/src/python_opensky/opensky.py b/src/python_opensky/opensky.py index 07f9e002..8afff507 100644 --- a/src/python_opensky/opensky.py +++ b/src/python_opensky/opensky.py @@ -10,11 +10,18 @@ from importlib import metadata from typing import TYPE_CHECKING, Any, cast -from aiohttp import BasicAuth, ClientError, ClientResponseError, ClientSession -from aiohttp.hdrs import METH_GET +from aiohttp import ClientError, ClientResponseError, ClientSession +from aiohttp.hdrs import METH_GET, METH_POST from yarl import URL -from .const import MAX_LATITUDE, MAX_LONGITUDE, MIN_LATITUDE, MIN_LONGITUDE +from .const import ( + MAX_LATITUDE, + MAX_LONGITUDE, + MIN_LATITUDE, + MIN_LONGITUDE, + TOKEN_REFRESH_MARGIN, + TOKEN_URL, +) from .exceptions import ( OpenSkyConnectionError, OpenSkyError, @@ -29,6 +36,16 @@ VERSION = metadata.version(__package__) +@dataclass +class _OAuthSession: + """OAuth2 client credentials and the access token they hold.""" + + client_id: str + client_secret: str + token: str | None = None + expires_at: datetime | None = None + + @dataclass class OpenSky: """Main class for handling connections with OpenSky.""" @@ -40,21 +57,23 @@ class OpenSky: timezone = UTC _close_session: bool = False _credit_usage: dict[datetime, int] = field(default_factory=dict) - _auth: BasicAuth | None = None + _oauth: _OAuthSession | None = None _contributing_user: bool = False async def authenticate( self, - auth: BasicAuth, + client_id: str, + client_secret: str, *, contributing_user: bool = False, ) -> None: """Authenticate the user.""" - self._auth = auth + self._oauth = _OAuthSession(client_id=client_id, client_secret=client_secret) try: + await self._refresh_token() await self.get_states(bounding_box=BoundingBox(0.0, 0.0, 1.0, 1.0)) except OpenSkyUnauthenticatedError as exc: - self._auth = None + self._oauth = None raise OpenSkyUnauthenticatedError from exc self._contributing_user = contributing_user if contributing_user: @@ -70,7 +89,64 @@ def is_contributing_user(self) -> bool: @property def is_authenticated(self) -> bool: """Return if the user is correctly authenticated.""" - return self._auth is not None + return self._oauth is not None + + async def _refresh_token(self) -> None: + """Refresh the OAuth2 access token.""" + assert self._oauth is not None # noqa: S101 — callers guard + if self.session is None: + self.session = ClientSession() + self._close_session = True + + try: + async with asyncio.timeout(self.request_timeout): + response = await self.session.request( + METH_POST, + TOKEN_URL, + data={ + "grant_type": "client_credentials", + "client_id": self._oauth.client_id, + "client_secret": self._oauth.client_secret, + }, + ) + except TimeoutError as exception: + msg = "Timeout occurred while connecting to the OpenSky API" + raise OpenSkyConnectionError(msg) from exception + except ( + ClientError, + ClientResponseError, + socket.gaierror, + ) as exception: + msg = "Error occurred while communicating with OpenSky API" + raise OpenSkyConnectionError(msg) from exception + + if response.status == 401: + raise OpenSkyUnauthenticatedError + + try: + response.raise_for_status() + except ClientResponseError as exception: + msg = "Error occurred while communicating with OpenSky API" + raise OpenSkyConnectionError(msg) from exception + + token_data = await response.json() + self._oauth.token = token_data["access_token"] + self._oauth.expires_at = datetime.now(UTC) + timedelta( + seconds=token_data["expires_in"] - TOKEN_REFRESH_MARGIN, + ) + + async def _get_access_token(self) -> str | None: + """Get a valid access token, refreshing if needed.""" + if self._oauth is None: + return None + if ( + self._oauth.token + and self._oauth.expires_at + and self._oauth.expires_at > datetime.now(UTC) + ): + return self._oauth.token + await self._refresh_token() + return self._oauth.token async def _request( self, @@ -116,12 +192,15 @@ async def _request( self.session = ClientSession() self._close_session = True + token = await self._get_access_token() + if token is not None: + headers["Authorization"] = f"Bearer {token}" + try: async with asyncio.timeout(self.request_timeout): response = await self.session.request( METH_GET, url.with_query(data), - auth=self._auth, headers=headers, ) response.raise_for_status() @@ -185,7 +264,7 @@ async def get_states( async def get_own_states(self, time: int = 0) -> StatesResponse: """Retrieve state vectors from your own sensors.""" - if not self._auth: + if self._oauth is None: raise OpenSkyUnauthenticatedError params = { "time": time, diff --git a/tests/test_states.py b/tests/test_states.py index 78cd5bee..c033eb15 100644 --- a/tests/test_states.py +++ b/tests/test_states.py @@ -2,10 +2,12 @@ import asyncio from dataclasses import asdict +from datetime import UTC, datetime +from typing import Any import aiohttp import pytest -from aiohttp import BasicAuth, ClientError +from aiohttp import ClientError from aiohttp.web_request import BaseRequest from aresponses import Response, ResponsesMockServer from syrupy.assertion import SnapshotAssertion @@ -22,6 +24,29 @@ from . import load_fixture OPENSKY_URL = "opensky-network.org" +TOKEN_HOST = "auth.opensky-network.org" # noqa: S105 +TOKEN_PATH = "/auth/realms/opensky-network/protocol/openid-connect/token" # noqa: S105 +TOKEN_RESPONSE = '{"access_token": "test-token", "expires_in": 1800}' # noqa: S105 + + +def _add_token_mock( + aresponses: ResponsesMockServer, + *, + repeat: int = 1, + status: int = 200, +) -> None: + """Add a token endpoint mock.""" + aresponses.add( + TOKEN_HOST, + TOKEN_PATH, + "POST", + aresponses.Response( + status=status, + headers={"Content-Type": "application/json"}, + text=TOKEN_RESPONSE, + ), + repeat=repeat, + ) async def test_states( @@ -73,6 +98,7 @@ async def test_own_states( aresponses: ResponsesMockServer, ) -> None: """Test retrieving own states.""" + _add_token_mock(aresponses) aresponses.add( OPENSKY_URL, "/api/states/all", @@ -96,7 +122,8 @@ async def test_own_states( async with aiohttp.ClientSession() as session: opensky = OpenSky(session=session) await opensky.authenticate( - BasicAuth(login="test", password="test"), + client_id="test_id", + client_secret="test_secret", contributing_user=True, ) response: StatesResponse = await opensky.get_own_states() @@ -110,6 +137,7 @@ async def test_unavailable_own_states( aresponses: ResponsesMockServer, ) -> None: """Test retrieving no own states.""" + _add_token_mock(aresponses) aresponses.add( OPENSKY_URL, "/api/states/all", @@ -133,7 +161,8 @@ async def test_unavailable_own_states( async with aiohttp.ClientSession() as session: opensky = OpenSky(session=session) await opensky.authenticate( - BasicAuth(login="test", password="test"), + client_id="test_id", + client_secret="test_secret", contributing_user=True, ) response: StatesResponse = await opensky.get_own_states() @@ -236,12 +265,13 @@ async def response_handler(_: BaseRequest) -> Response: async def test_auth(aresponses: ResponsesMockServer) -> None: """Test request authentication.""" + _add_token_mock(aresponses, repeat=1) def response_handler(request: BaseRequest) -> Response: """Response handler for this test.""" assert request.headers assert request.headers["Authorization"] - assert request.headers["Authorization"] == "Basic dGVzdDp0ZXN0" + assert request.headers["Authorization"] == "Bearer test-token" return aresponses.Response( status=200, headers={"Content-Type": "application/json"}, @@ -258,7 +288,10 @@ def response_handler(request: BaseRequest) -> Response: async with aiohttp.ClientSession() as session: opensky = OpenSky(session=session) - await opensky.authenticate(BasicAuth(login="test", password="test")) + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) await opensky.get_states() await opensky.close() @@ -266,20 +299,23 @@ def response_handler(request: BaseRequest) -> Response: async def test_unauthorized(aresponses: ResponsesMockServer) -> None: """Test request authentication.""" aresponses.add( - OPENSKY_URL, - "/api/states/all", - "GET", + TOKEN_HOST, + TOKEN_PATH, + "POST", aresponses.Response( status=401, headers={"Content-Type": "application/json"}, - text=load_fixture("states.json"), + text="{}", ), ) async with aiohttp.ClientSession() as session: opensky = OpenSky(session=session) try: - await opensky.authenticate(BasicAuth(login="test", password="test")) + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) pytest.fail("Should've thrown exception") except OpenSkyUnauthenticatedError: pass @@ -289,6 +325,7 @@ async def test_unauthorized(aresponses: ResponsesMockServer) -> None: async def test_user_credits(aresponses: ResponsesMockServer) -> None: """Test authenticated user credits.""" + _add_token_mock(aresponses, repeat=2) aresponses.add( OPENSKY_URL, "/api/states/all", @@ -303,10 +340,14 @@ async def test_user_credits(aresponses: ResponsesMockServer) -> None: async with aiohttp.ClientSession() as session: opensky = OpenSky(session=session) assert opensky.opensky_credits == 400 - await opensky.authenticate(BasicAuth(login="test", password="test")) + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) assert opensky.opensky_credits == 4000 await opensky.authenticate( - BasicAuth(login="test", password="test"), + client_id="test_id", + client_secret="test_secret", contributing_user=True, ) assert opensky.opensky_credits == 8000 @@ -397,3 +438,144 @@ async def test_calculating_credit_usage() -> None: max_longitude=10.9, ) assert opensky.calculate_credit_costs(bounding_box) == 4 + + +async def test_token_refresh(aresponses: ResponsesMockServer) -> None: + """Test that token is refreshed when expired.""" + _add_token_mock(aresponses, repeat=2) + aresponses.add( + OPENSKY_URL, + "/api/states/all", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text=load_fixture("states.json"), + ), + repeat=2, + ) + async with aiohttp.ClientSession() as session: + opensky = OpenSky(session=session) + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) + # Expire the token + assert opensky._oauth is not None # noqa: SLF001 + opensky._oauth.expires_at = datetime(2020, 1, 1, tzinfo=UTC) # noqa: SLF001 + await opensky.get_states() + await opensky.close() + + +async def test_token_refresh_new_session(aresponses: ResponsesMockServer) -> None: + """Test that _refresh_token creates a session if none exists.""" + _add_token_mock(aresponses) + aresponses.add( + OPENSKY_URL, + "/api/states/all", + "GET", + aresponses.Response( + status=200, + headers={"Content-Type": "application/json"}, + text=load_fixture("states.json"), + ), + ) + async with OpenSky() as opensky: + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) + assert opensky.session + + +async def test_token_refresh_timeout(aresponses: ResponsesMockServer) -> None: + """Test token refresh timeout.""" + + async def response_handler(_: BaseRequest) -> Response: + await asyncio.sleep(2) + return aresponses.Response(body="Timeout") + + aresponses.add( + TOKEN_HOST, + TOKEN_PATH, + "POST", + response_handler, + ) + + async with aiohttp.ClientSession() as session: + opensky = OpenSky(session=session, request_timeout=1) + with pytest.raises(OpenSkyConnectionError): + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) + await opensky.close() + + +async def test_token_refresh_connection_error() -> None: + """Test token refresh connection error.""" + async with aiohttp.ClientSession() as session: + opensky = OpenSky(session=session) + # Patch session.request to raise ClientError + original_request = session.request + + async def mock_request(*_args: Any, **_kwargs: Any) -> None: + raise ClientError + + session.request = mock_request # type: ignore[assignment] + with pytest.raises(OpenSkyConnectionError): + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) + session.request = original_request # type: ignore[assignment] + await opensky.close() + + +async def test_token_refresh_server_error( + aresponses: ResponsesMockServer, +) -> None: + """Test token refresh with server error response.""" + aresponses.add( + TOKEN_HOST, + TOKEN_PATH, + "POST", + aresponses.Response( + status=500, + headers={"Content-Type": "application/json"}, + text="{}", + ), + ) + + async with aiohttp.ClientSession() as session: + opensky = OpenSky(session=session) + with pytest.raises(OpenSkyConnectionError): + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) + await opensky.close() + + +async def test_api_returns_401(aresponses: ResponsesMockServer) -> None: + """Test that a 401 from the API raises OpenSkyUnauthenticatedError.""" + _add_token_mock(aresponses) + aresponses.add( + OPENSKY_URL, + "/api/states/all", + "GET", + aresponses.Response( + status=401, + headers={"Content-Type": "application/json"}, + text="{}", + ), + ) + async with aiohttp.ClientSession() as session: + opensky = OpenSky(session=session) + with pytest.raises(OpenSkyUnauthenticatedError): + await opensky.authenticate( + client_id="test_id", + client_secret="test_secret", + ) + assert opensky.is_authenticated is False + await opensky.close()