Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ from python_opensky import OpenSky, StatesResponse
async def main() -> None:
"""Show example of fetching all flight states."""
async with OpenSky() as opensky:
# Optional: authenticate for higher rate limits
# await opensky.authenticate(
# client_id="your_client_id",
# client_secret="your_client_secret",
# )
states: StatesResponse = await opensky.get_states()
print(states)

Expand Down
3 changes: 3 additions & 0 deletions src/python_opensky/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ class AircraftCategory(int, Enum):
MAX_LATITUDE = "lamax"
MIN_LONGITUDE = "lomin"
MAX_LONGITUDE = "lomax"

TOKEN_URL = "https://auth.opensky-network.org/auth/realms/opensky-network/protocol/openid-connect/token" # noqa: S105
TOKEN_REFRESH_MARGIN = 30
99 changes: 89 additions & 10 deletions src/python_opensky/opensky.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,18 @@
from importlib import metadata
from typing import TYPE_CHECKING, Any, cast

from aiohttp import BasicAuth, ClientError, ClientResponseError, ClientSession
from aiohttp.hdrs import METH_GET
from aiohttp import ClientError, ClientResponseError, ClientSession
from aiohttp.hdrs import METH_GET, METH_POST
from yarl import URL

from .const import MAX_LATITUDE, MAX_LONGITUDE, MIN_LATITUDE, MIN_LONGITUDE
from .const import (
MAX_LATITUDE,
MAX_LONGITUDE,
MIN_LATITUDE,
MIN_LONGITUDE,
TOKEN_REFRESH_MARGIN,
TOKEN_URL,
)
from .exceptions import (
OpenSkyConnectionError,
OpenSkyError,
Expand All @@ -29,6 +36,16 @@
VERSION = metadata.version(__package__)


@dataclass
class _OAuthSession:
"""OAuth2 client credentials and the access token they hold."""

client_id: str
client_secret: str
token: str | None = None
expires_at: datetime | None = None


@dataclass
class OpenSky:
"""Main class for handling connections with OpenSky."""
Expand All @@ -40,21 +57,23 @@ class OpenSky:
timezone = UTC
_close_session: bool = False
_credit_usage: dict[datetime, int] = field(default_factory=dict)
_auth: BasicAuth | None = None
_oauth: _OAuthSession | None = None
_contributing_user: bool = False

async def authenticate(
self,
auth: BasicAuth,
client_id: str,
client_secret: str,
*,
contributing_user: bool = False,
) -> None:
"""Authenticate the user."""
self._auth = auth
self._oauth = _OAuthSession(client_id=client_id, client_secret=client_secret)
try:
await self._refresh_token()
await self.get_states(bounding_box=BoundingBox(0.0, 0.0, 1.0, 1.0))
except OpenSkyUnauthenticatedError as exc:
self._auth = None
self._oauth = None
raise OpenSkyUnauthenticatedError from exc
self._contributing_user = contributing_user
if contributing_user:
Expand All @@ -70,7 +89,64 @@ def is_contributing_user(self) -> bool:
@property
def is_authenticated(self) -> bool:
"""Return if the user is correctly authenticated."""
return self._auth is not None
return self._oauth is not None

async def _refresh_token(self) -> None:
"""Refresh the OAuth2 access token."""
assert self._oauth is not None # noqa: S101 — callers guard
if self.session is None:
self.session = ClientSession()
self._close_session = True

try:
async with asyncio.timeout(self.request_timeout):
response = await self.session.request(
METH_POST,
TOKEN_URL,
data={
"grant_type": "client_credentials",
"client_id": self._oauth.client_id,
"client_secret": self._oauth.client_secret,
},
)
except TimeoutError as exception:
msg = "Timeout occurred while connecting to the OpenSky API"
raise OpenSkyConnectionError(msg) from exception
except (
ClientError,
ClientResponseError,
socket.gaierror,
) as exception:
msg = "Error occurred while communicating with OpenSky API"
raise OpenSkyConnectionError(msg) from exception

if response.status == 401:
raise OpenSkyUnauthenticatedError

try:
response.raise_for_status()
except ClientResponseError as exception:
msg = "Error occurred while communicating with OpenSky API"
raise OpenSkyConnectionError(msg) from exception

token_data = await response.json()
self._oauth.token = token_data["access_token"]
self._oauth.expires_at = datetime.now(UTC) + timedelta(
seconds=token_data["expires_in"] - TOKEN_REFRESH_MARGIN,
)

async def _get_access_token(self) -> str | None:
"""Get a valid access token, refreshing if needed."""
if self._oauth is None:
return None
if (
self._oauth.token
and self._oauth.expires_at
and self._oauth.expires_at > datetime.now(UTC)
):
return self._oauth.token
await self._refresh_token()
return self._oauth.token

async def _request(
self,
Expand Down Expand Up @@ -116,12 +192,15 @@ async def _request(
self.session = ClientSession()
self._close_session = True

token = await self._get_access_token()
if token is not None:
headers["Authorization"] = f"Bearer {token}"

try:
async with asyncio.timeout(self.request_timeout):
response = await self.session.request(
METH_GET,
url.with_query(data),
auth=self._auth,
headers=headers,
)
response.raise_for_status()
Expand Down Expand Up @@ -185,7 +264,7 @@ async def get_states(

async def get_own_states(self, time: int = 0) -> StatesResponse:
"""Retrieve state vectors from your own sensors."""
if not self._auth:
if self._oauth is None:
raise OpenSkyUnauthenticatedError
params = {
"time": time,
Expand Down
Loading