From e94f45e46197fce9e65459a42a71848bbedc6903 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 12 Jun 2026 15:53:55 +0000 Subject: [PATCH 1/3] Add reusable verified-inference client utilities (OHTTP + signature verification) Expose the client side of the OpenGradient verified-inference protocol so any integrator (the chat-app relay, third-party tools, and the new local proxy) can route an OpenAI-style request to a TEE through an untrusted relay and cryptographically verify the response, sharing one non-drifting implementation. - tee_ohttp.py: client-side Oblivious HTTP (RFC 9458) encapsulation + single-shot and chunked-streaming response decryption, wire-compatible with the tee-gateway recipient and the chat-app browser client. - tee_verify.py: RSA-PSS response signature verification, request canonicalization (build_inner_request), and tee_id/PCR helpers, mirroring the gateway's signing. - tee_ohttp_client.py: high-level OhttpRelayClient tying registry + OHTTP + verify together; verifies before returning (buffers streams) so no unverified token is surfaced. Caller supplies relay auth headers. - tee_registry: surface on-chain pcr_hash on TEEEndpoint for reproducible-build PCR pinning. - Tests round-trip the OHTTP crypto against the real tee-gateway recipient code and verify signatures against an independently-constructed gateway signature. https://claude.ai/code/session_01PdYbDC47zuBGiZex7ZHMSs --- pyproject.toml | 4 + src/opengradient/__init__.py | 35 +- src/opengradient/client/__init__.py | 26 +- src/opengradient/client/tee_ohttp.py | 242 +++++++++++++ src/opengradient/client/tee_ohttp_client.py | 313 ++++++++++++++++ src/opengradient/client/tee_registry.py | 6 + src/opengradient/client/tee_verify.py | 377 ++++++++++++++++++++ tests/tee_ohttp_client_test.py | 156 ++++++++ tests/tee_ohttp_test.py | 91 +++++ tests/tee_verify_test.py | 158 ++++++++ uv.lock | 20 +- 11 files changed, 1423 insertions(+), 5 deletions(-) create mode 100644 src/opengradient/client/tee_ohttp.py create mode 100644 src/opengradient/client/tee_ohttp_client.py create mode 100644 src/opengradient/client/tee_verify.py create mode 100644 tests/tee_ohttp_client_test.py create mode 100644 tests/tee_ohttp_test.py create mode 100644 tests/tee_verify_test.py diff --git a/pyproject.toml b/pyproject.toml index 09ed962..de454d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,10 @@ dependencies = [ "pydantic>=2.9.2", "og-x402>=0.0.2.dev2", "og-x402[extensions]>=0.0.2.dev2", + # Verified-inference utilities: HPKE for client-side Oblivious HTTP, and + # cryptography for RSA-PSS signature verification. + "pyhpke>=0.6.0", + "cryptography>=43.0.0", ] [project.optional-dependencies] diff --git a/src/opengradient/__init__.py b/src/opengradient/__init__.py index 201513a..e035090 100644 --- a/src/opengradient/__init__.py +++ b/src/opengradient/__init__.py @@ -77,7 +77,21 @@ async def stream_example(): """ from . import agents, alphasense -from .client import LLM, Alpha, ModelHub, Twins +from .client import ( + LLM, + Alpha, + ModelHub, + OhttpRelayClient, + RelayError, + TEEEndpoint, + TEERegistry, + TeeProof, + Twins, + VerificationError, + VerifiedChatResponse, + build_inner_request, + verify_response, +) from .types import ( TEE_LLM, CandleOrder, @@ -112,6 +126,16 @@ async def stream_example(): "x402SettlementMode", "agents", "alphasense", + # Verified-inference building blocks + "TEERegistry", + "TEEEndpoint", + "OhttpRelayClient", + "VerifiedChatResponse", + "RelayError", + "TeeProof", + "VerificationError", + "build_inner_request", + "verify_response", ] __pdoc__ = { @@ -140,4 +164,13 @@ async def stream_example(): "CandleType": False, "HistoricalInputQuery": False, "SchedulerParams": False, + "TEERegistry": False, + "TEEEndpoint": False, + "OhttpRelayClient": False, + "VerifiedChatResponse": False, + "RelayError": False, + "TeeProof": False, + "VerificationError": False, + "build_inner_request": False, + "verify_response": False, } diff --git a/src/opengradient/client/__init__.py b/src/opengradient/client/__init__.py index 9ac760b..b5ea1cd 100644 --- a/src/opengradient/client/__init__.py +++ b/src/opengradient/client/__init__.py @@ -33,9 +33,28 @@ from .alpha import Alpha from .llm import LLM from .model_hub import ModelHub +from .tee_ohttp_client import OhttpRelayClient, RelayError, VerifiedChatResponse +from .tee_registry import TEEEndpoint, TEERegistry +from .tee_verify import TeeProof, VerificationError, build_inner_request, verify_response from .twins import Twins -__all__ = ["LLM", "Alpha", "ModelHub", "Twins"] +__all__ = [ + "LLM", + "Alpha", + "ModelHub", + "Twins", + # Verified-inference building blocks: route an OpenAI-style request to a TEE + # through an untrusted relay, then cryptographically verify the response. + "TEERegistry", + "TEEEndpoint", + "OhttpRelayClient", + "VerifiedChatResponse", + "RelayError", + "TeeProof", + "VerificationError", + "build_inner_request", + "verify_response", +] __pdoc__ = { "Alpha": False, @@ -45,5 +64,8 @@ "client": False, "exceptions": False, "opg_token": False, - "tee_registry": False, + "tee_registry": True, + "tee_ohttp": True, + "tee_verify": True, + "tee_ohttp_client": True, } diff --git a/src/opengradient/client/tee_ohttp.py b/src/opengradient/client/tee_ohttp.py new file mode 100644 index 0000000..d39f969 --- /dev/null +++ b/src/opengradient/client/tee_ohttp.py @@ -0,0 +1,242 @@ +"""Client-side Oblivious HTTP (RFC 9458) encapsulation for anonymous TEE inference. + +This is the *sender* side of the construction the tee-gateway implements on the +recipient side and the chat-app implements in the browser. Using it, a client can +HPKE-encrypt an inference request to a TEE's published X25519 key, send it through +an untrusted relay, and decrypt the (single-shot or chunked-streaming) response — +the relay only ever sees ciphertext. + +The ciphersuite is fixed and must match the enclave and the on-chain +`opengradient.client.tee_registry.OhttpConfig`: + + - KEM: DHKEM(X25519, HKDF-SHA256) (0x0020) + - KDF: HKDF-SHA256 (0x0001) + - AEAD: ChaCha20-Poly1305 (0x0003) + +We use `pyhpke` for the HPKE sender context (the same library the gateway uses on +the recipient side, guaranteeing wire compatibility) and derive the response keys +with the same manual HKDF the gateway uses, so responses decrypt byte-for-byte. + +Wire formats: + Request: header(7) || enc(32) || AEAD ciphertext + Response: response_nonce(32) || AEAD ciphertext (single-shot) + Chunked: response_nonce(32) || (varint(len)||sealed)+ || varint(0)||final +""" + +from __future__ import annotations + +import struct +from dataclasses import dataclass + +from cryptography.hazmat.primitives import hashes, hmac +from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 +from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand +from pyhpke import AEADId, CipherSuite, KDFId, KEMId + +# RFC 9180 / 9458 algorithm identifiers (fixed suite). +KEY_CONFIG_ID = 0x01 +KEM_ID_X25519 = 0x0020 +KDF_ID_HKDF_SHA256 = 0x0001 +AEAD_ID_CHACHA20_POLY1305 = 0x0003 + +_NK = 32 # AEAD key length / response_nonce length (== max(Nn, Nk)) +_NN = 12 # AEAD nonce length + +_LABEL_REQUEST = b"message/bhttp request" +_LABEL_RESPONSE = b"message/bhttp response" +_LABEL_CHUNKED_RESPONSE = b"message/bhttp chunked response" + +_SUITE = CipherSuite.new( + KEMId.DHKEM_X25519_HKDF_SHA256, + KDFId.HKDF_SHA256, + AEADId.CHACHA20_POLY1305, +) + + +def _header_bytes() -> bytes: + return bytes([KEY_CONFIG_ID]) + struct.pack(">HHH", KEM_ID_X25519, KDF_ID_HKDF_SHA256, AEAD_ID_CHACHA20_POLY1305) + + +@dataclass +class EncapsulatedRequest: + """An HPKE-sealed request plus the secrets needed to open its response. + + Attributes: + wire: The bytes to send to the relay (header || enc || ciphertext). + enc: Our ephemeral X25519 public key; salts the response keying. + response_secret: Exported secret for a single-shot response. + chunked_response_secret: Exported secret for a chunked-streaming response. + """ + + wire: bytes + enc: bytes + response_secret: bytes + chunked_response_secret: bytes + + +def encapsulate_request(public_key_raw: bytes, plaintext: bytes) -> EncapsulatedRequest: + """HPKE-seal ``plaintext`` to a TEE's raw X25519 public key. + + Args: + public_key_raw: The 32-byte raw X25519 public key from the TEE's OHTTP + config (``OhttpConfig.public_key``). + plaintext: The inner request body (typically a UTF-8 JSON chat request). + + Returns: + An `EncapsulatedRequest` ready to send to a relay. + + Raises: + ValueError: If ``public_key_raw`` is not 32 bytes. + """ + if len(public_key_raw) != 32: + raise ValueError("X25519 public key must be 32 bytes") + + pkr = _SUITE.kem.deserialize_public_key(public_key_raw) + info = _LABEL_REQUEST + b"\x00" + _header_bytes() + enc, sender = _SUITE.create_sender_context(pkr, info=info) + + ciphertext = sender.seal(plaintext, aad=b"") + wire = _header_bytes() + bytes(enc) + ciphertext + + export_len = max(_NN, _NK) + return EncapsulatedRequest( + wire=wire, + enc=bytes(enc), + response_secret=sender.export(_LABEL_RESPONSE, export_len), + chunked_response_secret=sender.export(_LABEL_CHUNKED_RESPONSE, export_len), + ) + + +def _derive_response_keys(response_secret: bytes, enc: bytes, response_nonce: bytes) -> tuple[bytes, bytes]: + """HKDF-Extract(salt=enc||response_nonce, ikm=response_secret) then Expand. + + Byte-identical to the gateway's response-key derivation, so both single-shot + and chunked responses decrypt correctly. + """ + h = hmac.HMAC(enc + response_nonce, hashes.SHA256()) + h.update(response_secret) + prk = h.finalize() + aead_key = HKDFExpand(algorithm=hashes.SHA256(), length=_NK, info=b"key").derive(prk) + aead_nonce = HKDFExpand(algorithm=hashes.SHA256(), length=_NN, info=b"nonce").derive(prk) + return aead_key, aead_nonce + + +def decrypt_response(response_secret: bytes, enc: bytes, sealed: bytes) -> bytes: + """Decrypt a single-shot OHTTP response (RFC 9458 §4.5). + + Args: + response_secret: ``EncapsulatedRequest.response_secret``. + enc: ``EncapsulatedRequest.enc``. + sealed: The full response body from the relay. + + Returns: + The decrypted inner response bytes. + + Raises: + ValueError: If the response is too short to be well-formed. + """ + if len(sealed) <= _NK: + raise ValueError("malformed OHTTP response") + response_nonce = sealed[:_NK] + ciphertext = sealed[_NK:] + aead_key, aead_nonce = _derive_response_keys(response_secret, enc, response_nonce) + return ChaCha20Poly1305(aead_key).decrypt(aead_nonce, ciphertext, b"") + + +def _decode_varint(buf: bytes, offset: int) -> tuple[int, int] | None: + """Parse one QUIC varint; returns ``(value, new_offset)`` or ``None`` if more bytes are needed.""" + if offset >= len(buf): + return None + first = buf[offset] + length = 1 << (first >> 6) + if offset + length > len(buf): + return None + value = first & 0x3F + for i in range(1, length): + value = (value << 8) | buf[offset + i] + return value, offset + length + + +class ChunkedResponseDecrypter: + """Incrementally decrypt a chunked OHTTP response stream (draft-ietf-ohai-chunked-ohttp-08). + + Feed it raw response bytes as they arrive; it yields decrypted plaintext + frames (typically the inner SSE ``data:`` events). The final frame carries + AAD=b"final"; its absence at end-of-stream is treated as truncation, so a + network attacker cannot silently cut a stream short. + """ + + def __init__(self, response_secret: bytes, enc: bytes): + self._response_secret = response_secret + self._enc = enc + self._buffer = bytearray() + self._key: bytes | None = None + self._nonce: bytes | None = None + self._counter = 0 + self._saw_final = False + + def push(self, chunk: bytes | None, done: bool) -> list[bytes]: + """Feed bytes and return any newly-decrypted plaintext frames. + + Args: + chunk: Newly-received bytes (or ``None``). + done: Whether the underlying stream has ended. + + Returns: + A list of decrypted plaintext frames (possibly empty). + + Raises: + ValueError: On a malformed or truncated stream. + """ + if chunk: + self._buffer.extend(chunk) + + if self._key is None or self._nonce is None: + if len(self._buffer) < _NK: + if done: + raise ValueError("malformed chunked OHTTP response") + return [] + response_nonce = bytes(self._buffer[:_NK]) + self._key, self._nonce = _derive_response_keys(self._response_secret, self._enc, response_nonce) + del self._buffer[:_NK] + + out: list[bytes] = [] + while self._buffer: + frame = _decode_varint(self._buffer, 0) + if frame is None: + if done: + raise ValueError("malformed chunked OHTTP response") + break + sealed_len, offset = frame + + if sealed_len == 0: + # Zero-length prefix marks the final chunk; AAD=b"final". + if not done: + break + ciphertext = bytes(self._buffer[offset:]) + out.append(self._decrypt_chunk(ciphertext, is_final=True)) + self._buffer.clear() + self._saw_final = True + break + + if len(self._buffer) < offset + sealed_len: + if done: + raise ValueError("truncated chunked OHTTP response") + break + + ciphertext = bytes(self._buffer[offset : offset + sealed_len]) + out.append(self._decrypt_chunk(ciphertext, is_final=False)) + del self._buffer[: offset + sealed_len] + + if done and not self._saw_final: + raise ValueError("chunked OHTTP response missing final marker") + return out + + def _decrypt_chunk(self, ciphertext: bytes, is_final: bool) -> bytes: + assert self._key is not None and self._nonce is not None + ctr = self._counter.to_bytes(_NN, "big") + chunk_nonce = bytes(a ^ b for a, b in zip(self._nonce, ctr)) + aad = b"final" if is_final else b"" + plaintext = ChaCha20Poly1305(self._key).decrypt(chunk_nonce, ciphertext, aad) + self._counter += 1 + return plaintext diff --git a/src/opengradient/client/tee_ohttp_client.py b/src/opengradient/client/tee_ohttp_client.py new file mode 100644 index 0000000..e5593eb --- /dev/null +++ b/src/opengradient/client/tee_ohttp_client.py @@ -0,0 +1,313 @@ +"""High-level Oblivious HTTP relay client for verified, private TEE inference. + +This ties together the three lower-level pieces so an integrator doesn't have to: + + 1. `opengradient.client.tee_registry` — discover a TEE (endpoint, OHTTP key, + signing key) from the on-chain registry. + 2. `opengradient.client.tee_ohttp` — HPKE-encrypt the request and decrypt the + response. + 3. `opengradient.client.tee_verify` — verify the enclave's RSA-PSS signature. + +The relay (which holds the x402 wallet / account credentials and pays per +request) only ever sees ciphertext. Authentication to the relay is left to the +caller: pass an ``auth_headers`` provider returning whatever the relay expects +(e.g. ``{"Authorization": "Bearer "}``), so this client works for any +relay deployment without baking in a credential scheme. + +Verification happens **before** any content is returned. For streaming requests +the full encrypted stream is buffered, verified, and only then handed back as +decrypted SSE frames — so a caller can guarantee no unverified token ever +reaches the end user, at the cost of streaming latency. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Callable, Iterable, Optional + +import requests + +from .tee_ohttp import ChunkedResponseDecrypter, encapsulate_request +from .tee_registry import TEEEndpoint +from .tee_verify import ( + TeeProof, + VerificationError, + build_inner_request, + pem_from_der, + response_content_for_hash, + verify_response, +) + +OHTTP_REQUEST_MEDIA_TYPE = "message/ohttp-req" +OHTTP_RESPONSE_MEDIA_TYPE = "message/ohttp-res" +OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE = "message/ohttp-chunked-res" + +AuthHeaderProvider = Callable[[], dict] + + +class RelayError(Exception): + """The relay returned a non-success status, or the inner response was an error. + + Attributes: + status_code: The HTTP (or inner) status code. + message: A human-readable error message extracted from the response. + """ + + def __init__(self, status_code: int, message: str): + super().__init__(f"relay error {status_code}: {message}") + self.status_code = status_code + self.message = message + + +@dataclass +class VerifiedChatResponse: + """A TEE chat response that has passed signature verification. + + Attributes: + body: The inner response JSON (the single-shot body, or the final SSE + frame for a stream). + content: The assistant text (or tool-calls JSON) that was verified. + proof: The :class:`opengradient.client.tee_verify.TeeProof`. + stream_frames: For streaming requests, the decrypted inner SSE ``data:`` + event strings (already verified), ready to replay to a client; + ``None`` for single-shot requests. + """ + + body: dict + content: str + proof: TeeProof + stream_frames: Optional[list[str]] = None + + +class OhttpRelayClient: + """Send verified, private chat completions to a TEE through an OHTTP relay. + + Args: + relay_url: Full URL to POST encapsulated requests to (e.g. + ``https://chat-api.example.com/api/v1/chat/ohttp``). + tee: The :class:`opengradient.client.tee_registry.TEEEndpoint` to encrypt + to (must carry an ``ohttp_config`` and ``signing_public_key_der``). + auth_headers: Optional callable returning headers to authenticate to the + relay (called per request so tokens can be refreshed). + session: Optional ``requests.Session`` to reuse connections. + timeout: Per-request timeout in seconds. + """ + + def __init__( + self, + relay_url: str, + tee: TEEEndpoint, + *, + auth_headers: Optional[AuthHeaderProvider] = None, + session: Optional[requests.Session] = None, + timeout: float = 120.0, + ): + if tee.ohttp_config is None or len(tee.ohttp_config.public_key) != 32: + raise ValueError("TEEEndpoint has no usable OHTTP config") + if not tee.signing_public_key_der: + raise ValueError("TEEEndpoint is missing a signing public key") + self._relay_url = relay_url + self._tee = tee + self._ohttp_public_key = tee.ohttp_config.public_key + self._auth_headers = auth_headers + self._session = session or requests.Session() + self._timeout = timeout + self._signing_key_pem = pem_from_der(tee.signing_public_key_der) + + def chat_completion(self, body: dict) -> VerifiedChatResponse: + """Send a non-streaming chat completion and return a verified response. + + Args: + body: An OpenAI ``/v1/chat/completions`` request body. + + Returns: + A :class:`VerifiedChatResponse`. + + Raises: + RelayError: If the relay or the inner request errored. + VerificationError: If the response signature could not be verified. + opengradient.client.tee_verify.UnsupportedRequestError: If the body is invalid. + """ + wire, canonical = build_inner_request(body) + enc = encapsulate_request(self._ohttp_public_key, json.dumps(wire).encode("utf-8")) + + resp = self._session.post( + self._relay_url, + data=enc.wire, + headers=self._headers(stream=False), + timeout=self._timeout, + ) + if not resp.ok: + raise RelayError(resp.status_code, _error_message(resp.content)) + + from .tee_ohttp import decrypt_response + + inner_bytes = decrypt_response(enc.response_secret, enc.enc, resp.content) + status, inner = _normalize_inner(json.loads(inner_bytes.decode("utf-8"))) + if status >= 400: + raise RelayError(status, str(inner.get("error", "TEE inner error"))) + + content = response_content_for_hash(inner) + proof = verify_response( + canonical_request=canonical, + response_body=inner, + response_content=content, + signing_key_pem=self._signing_key_pem, + expected_tee_id=self._tee.tee_id, + tee_host=self._tee.endpoint, + ) + return VerifiedChatResponse(body=inner, content=content, proof=proof) + + def stream_chat_completion(self, body: dict) -> VerifiedChatResponse: + """Send a streaming chat completion, verify it, then return decrypted frames. + + The encrypted stream is fully buffered and verified before returning, so + the returned ``stream_frames`` are safe to replay to an end user. (This + trades streaming latency for the "no unverified token leaves the machine" + guarantee.) + + Args: + body: An OpenAI ``/v1/chat/completions`` request body (``stream`` is + forced on for the wire request). + + Returns: + A :class:`VerifiedChatResponse` with ``stream_frames`` populated. + + Raises: + RelayError, VerificationError, UnsupportedRequestError: As for + :meth:`chat_completion`. + """ + wire, canonical = build_inner_request(body) + wire = {**wire, "stream": True} + enc = encapsulate_request(self._ohttp_public_key, json.dumps(wire).encode("utf-8")) + + resp = self._session.post( + self._relay_url, + data=enc.wire, + headers=self._headers(stream=True), + timeout=self._timeout, + stream=True, + ) + if not resp.ok: + raise RelayError(resp.status_code, _error_message(resp.content)) + + decrypter = ChunkedResponseDecrypter(enc.chunked_response_secret, enc.enc) + frames: list[str] = [] + full_content = "" + final_frame: Optional[dict] = None + + chunks = resp.iter_content(chunk_size=8192) + for raw, is_last in _with_last(chunks): + for plaintext in decrypter.push(raw, done=is_last): + text = plaintext.decode("utf-8", errors="replace") + frames.append(text) + delta, final = _parse_sse_frame(text) + full_content += delta + if final is not None: + final_frame = final + + if final_frame is None: + raise VerificationError("TEE stream missing a signed final frame") + + proof = verify_response( + canonical_request=canonical, + response_body=final_frame, + response_content=full_content, + signing_key_pem=self._signing_key_pem, + expected_tee_id=self._tee.tee_id, + tee_host=self._tee.endpoint, + ) + return VerifiedChatResponse( + body=final_frame, content=full_content, proof=proof, stream_frames=frames + ) + + def _headers(self, *, stream: bool) -> dict: + headers = { + "Content-Type": OHTTP_REQUEST_MEDIA_TYPE, + "Accept": OHTTP_CHUNKED_RESPONSE_MEDIA_TYPE if stream else OHTTP_RESPONSE_MEDIA_TYPE, + "X-TEE-ID": self._tee.tee_id, + } + if stream: + headers["X-OHTTP-Stream"] = "true" + if self._auth_headers: + headers.update(self._auth_headers()) + return headers + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _normalize_inner(decoded) -> tuple[int, dict]: + """Accept both ``{status, body}`` envelopes and a bare response object.""" + if isinstance(decoded, dict) and isinstance(decoded.get("status"), int) and isinstance(decoded.get("body"), dict): + return decoded["status"], decoded["body"] + if isinstance(decoded, dict): + return 200, decoded + raise VerificationError("malformed inner response") + + +def _parse_sse_frame(text: str) -> tuple[str, Optional[dict]]: + """Parse one decrypted SSE frame; return ``(delta_text, final_frame_or_None)``.""" + delta = "" + final: Optional[dict] = None + for line in text.splitlines(): + line = line.strip() + if not line.startswith("data:"): + continue + payload = line[len("data:") :].strip() + if not payload or payload == "[DONE]": + continue + try: + parsed = json.loads(payload) + except json.JSONDecodeError: + continue + if not isinstance(parsed, dict): + continue + if isinstance(parsed.get("error"), str): + raise RelayError(502, parsed["error"]) + delta += _delta_content(parsed) + if isinstance(parsed.get("tee_signature"), str) or isinstance(parsed.get("tee_output_hash"), str): + final = parsed + return delta, final + + +def _delta_content(frame: dict) -> str: + choices = frame.get("choices") + if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict): + return "" + delta = choices[0].get("delta") + if not isinstance(delta, dict): + return "" + content = delta.get("content") + if isinstance(content, list): + return "".join(p.get("text", "") if isinstance(p, dict) else str(p) for p in content) + return content if isinstance(content, str) else "" + + +def _with_last(iterable: Iterable[bytes]): + """Yield ``(item, is_last)`` pairs, with ``is_last`` True only for the final item.""" + iterator = iter(iterable) + try: + prev = next(iterator) + except StopIteration: + # Empty stream: signal one final, empty push so the decrypter can report + # the missing-final-marker truncation error rather than hanging. + yield b"", True + return + for item in iterator: + yield prev, False + prev = item + yield prev, True + + +def _error_message(content: bytes) -> str: + try: + body = json.loads(content.decode("utf-8")) + if isinstance(body, dict): + return str(body.get("detail") or body.get("error") or "relay error") + except (UnicodeDecodeError, json.JSONDecodeError): + pass + return content.decode("utf-8", errors="replace")[:500] or "relay error" diff --git a/src/opengradient/client/tee_registry.py b/src/opengradient/client/tee_registry.py index de7be8e..24288c0 100644 --- a/src/opengradient/client/tee_registry.py +++ b/src/opengradient/client/tee_registry.py @@ -81,6 +81,10 @@ class TEEEndpoint: payment_address: x402 settlement address for this TEE. signing_public_key_der: DER (SPKI) RSA public key the TEE signs with. ohttp_config: The TEE's OHTTP/HPKE key configuration, if present. + pcr_hash: The reproducible-build PCR measurement hash recorded on-chain + (``0x``-prefixed hex). Lets a caller refuse any TEE whose code + fingerprint differs from a known-good build — trusting math over the + registry operator. """ tee_id: str @@ -89,6 +93,7 @@ class TEEEndpoint: payment_address: str signing_public_key_der: bytes = b"" ohttp_config: Optional[OhttpConfig] = None + pcr_hash: str = "" class TEERegistry: @@ -154,6 +159,7 @@ def get_active_tees_by_type(self, tee_type: int) -> List[TEEEndpoint]: payment_address=tee.payment_address, signing_public_key_der=bytes(tee.public_key), ohttp_config=_parse_ohttp_config(tee.ohttp_config), + pcr_hash="0x" + bytes(tee.pcr_hash).hex(), ) ) diff --git a/src/opengradient/client/tee_verify.py b/src/opengradient/client/tee_verify.py new file mode 100644 index 0000000..da38c17 --- /dev/null +++ b/src/opengradient/client/tee_verify.py @@ -0,0 +1,377 @@ +"""Cryptographic verification of TEE inference responses. + +Every response from an OpenGradient TEE gateway carries an RSA-PSS signature over +``keccak256(requestHash || outputHash || uint256(timestamp))``, produced inside +the enclave with the key the on-chain registry records for that TEE. The trust +chain is: + + reproducible build -> PCRs -> on-chain registry entry (pcrHash + signing key) + -> per-response RSA-PSS signature + +So if you trust a TEE's registry signing key (optionally pinned to an expected +``pcrHash``) and the signature verifies, the response was produced inside that +attested enclave and was not modified in transit — no trust in the relay, the +host, or us is required. The relay never holds the signing key and so cannot +forge a response. + +This module mirrors the gateway's signing (``compute_tee_msg_hash`` + RSA-PSS, +salt length 32) and the chat-app's browser verification, kept in one place so all +clients verify identically. +""" + +from __future__ import annotations + +import base64 +import copy +import json +from dataclasses import dataclass +from typing import Any, Optional + +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey +from eth_hash.auto import keccak + + +class VerificationError(Exception): + """Raised when a response fails any step of TEE verification. + + Callers should treat this as fatal: never surface content that failed + verification to the end user. + """ + + +class UnsupportedRequestError(Exception): + """Raised when an OpenAI-style request cannot be expressed as a gateway request.""" + + +@dataclass +class TeeProof: + """The verified provenance of a single response. + + Attributes: + tee_id: The TEE identity (``0x`` + keccak256 of the signing key DER). + request_hash: keccak256 of the canonical request, hex (no ``0x``). + output_hash: keccak256 of the signed output content, hex (no ``0x``). + timestamp: The enclave-asserted signing timestamp (unix seconds). + signature: The base64 RSA-PSS signature that was verified. + signing_key_pem: The PEM signing key the signature verified against. + tee_host: Optional host the response came from, for display. + """ + + tee_id: str + request_hash: str + output_hash: str + timestamp: int + signature: str + signing_key_pem: str + tee_host: Optional[str] = None + + +# --------------------------------------------------------------------------- +# Signing-key helpers +# --------------------------------------------------------------------------- + + +def pem_from_der(signing_public_key_der: bytes) -> str: + """Convert a DER (SPKI) public key (e.g. ``TEEEndpoint.signing_public_key_der``) to PEM.""" + key = serialization.load_der_public_key(bytes(signing_public_key_der)) + return key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ).decode("utf-8") + + +def tee_id_for_key(signing_key_pem: str) -> str: + """Return ``0x`` + keccak256(DER(SubjectPublicKeyInfo)). + + Matches the gateway's ``TEEKeyManager.tee_id`` and the registry's keyed tee_id. + """ + key = serialization.load_pem_public_key(signing_key_pem.encode("utf-8")) + der = key.public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return "0x" + keccak(der).hex() + + +# --------------------------------------------------------------------------- +# Request canonicalization (must byte-match the gateway's request hashing) +# --------------------------------------------------------------------------- + + +def canonical_user_content(content: Any) -> Any: + """Canonicalize user-message content for request hashing. + + Plain strings pass through. For multimodal content (a list of parts), text is + kept verbatim and every attachment is reduced to ``{type[, filename]}`` — the + inline bytes are dropped (they ride inside the encrypted envelope and are never + hashed). Mirrors the gateway's ``canonical_user_content``. + """ + if isinstance(content, str): + return content + if not isinstance(content, list): + return str(content) + + canonical: list[Any] = [] + for part in content: + if not isinstance(part, dict): + canonical.append({"type": "text", "text": str(part)}) + continue + if part.get("type") == "text": + canonical.append({"type": "text", "text": part.get("text", "") or ""}) + continue + entry: dict[str, Any] = {"type": part.get("type")} + file_obj = part.get("file") + filename = (file_obj.get("filename") if isinstance(file_obj, dict) else None) or part.get("filename") + if filename: + entry["filename"] = filename + canonical.append(entry) + return canonical + + +def _canonical_message(msg: dict) -> dict: + """Shape one message exactly as the gateway does before hashing.""" + role = msg.get("role") + if role in ("system", "developer"): # gateway treats developer as system + return {"role": "system", "content": msg.get("content")} + if role == "user": + return {"role": "user", "content": canonical_user_content(msg.get("content"))} + if role == "assistant": + out: dict[str, Any] = {"role": "assistant", "content": msg.get("content") or ""} + tool_calls = msg.get("tool_calls") + if tool_calls: + out["tool_calls"] = [ + { + "id": tc.get("id", ""), + "type": tc.get("type", "function"), + "function": { + "name": tc.get("function", {}).get("name", ""), + "arguments": tc.get("function", {}).get("arguments", ""), + }, + } + for tc in tool_calls + ] + return out + if role == "tool": + return {"role": "tool", "content": msg.get("content"), "tool_call_id": msg.get("tool_call_id")} + if role == "function": + return {"role": "function", "content": msg.get("content"), "name": msg.get("name")} + raise UnsupportedRequestError(f"unknown message role: {role!r}") + + +def build_inner_request(body: dict) -> tuple[dict, dict]: + """Build ``(wire, canonical)`` request dicts from an OpenAI chat-completions ``body``. + + The gateway commits (in its signed request hash) to only a fixed subset of + fields, so anything else the caller sends (``n``, ``top_p``, ``tool_choice``, + ...) is intentionally dropped — it would not be covered by the signature. + + Args: + body: An OpenAI ``/v1/chat/completions`` request body. + + Returns: + ``(wire, canonical)`` where ``wire`` is the object to encrypt and send + (full message content preserved so the model sees attachment bytes), and + ``canonical`` is the dict whose ``json.dumps(sort_keys=True)`` the gateway + hashes (attachment bytes stripped). Pass ``canonical`` to + :func:`verify_response`. + + Raises: + UnsupportedRequestError: If the body is missing ``model`` or ``messages`` or + contains an unknown message role. + """ + if not isinstance(body, dict): + raise UnsupportedRequestError("request body must be a JSON object") + model = body.get("model") + if not isinstance(model, str) or not model: + raise UnsupportedRequestError("request is missing a 'model'") + messages = body.get("messages") + if not isinstance(messages, list) or not messages: + raise UnsupportedRequestError("request is missing 'messages'") + + canonical_messages = [_canonical_message(m) for m in messages] + + temperature = body.get("temperature") + canonical: dict[str, Any] = { + "model": model, + "messages": canonical_messages, + "temperature": float(temperature) if temperature is not None else 0.0, + } + if body.get("max_tokens") is not None: + canonical["max_tokens"] = body["max_tokens"] + if body.get("stop"): + canonical["stop"] = body["stop"] + if body.get("tools"): + tools = body["tools"] + canonical["tools"] = tools if isinstance(tools, list) else list(tools) + if body.get("response_format"): + canonical["response_format"] = body["response_format"] + if body.get("web_search"): + canonical["web_search"] = True + + wire = copy.deepcopy(canonical) + wire["messages"] = [ + ({**cm, "content": orig.get("content")} if cm.get("role") == "user" else cm) + for cm, orig in zip(wire["messages"], messages) + ] + return wire, canonical + + +def canonical_request_bytes(canonical_request: dict) -> bytes: + """Serialize a canonical request exactly as the gateway hashes it: ``json.dumps(sort_keys=True)``.""" + return json.dumps(canonical_request, sort_keys=True).encode("utf-8") + + +# --------------------------------------------------------------------------- +# Response content + signature verification +# --------------------------------------------------------------------------- + + +def response_content_for_hash(response_body: dict) -> str: + """Extract the exact string the gateway hashed as the signed output. + + For tool-call responses the gateway hashes ``json.dumps(tool_calls, sort_keys=True)``; + otherwise it hashes the assistant message text (generated image bytes are + excluded — they ride out-of-band and are not signed). + """ + choice = _first_choice(response_body) + message = choice.get("message") if isinstance(choice, dict) else None + if isinstance(message, dict): + if choice.get("finish_reason") == "tool_calls" and isinstance(message.get("tool_calls"), list): + return json.dumps(message["tool_calls"], sort_keys=True) + return _content_text(message.get("content")) + return "" + + +def verify_response( + *, + canonical_request: dict, + response_body: dict, + response_content: str, + signing_key_pem: str, + expected_tee_id: Optional[str] = None, + tee_host: Optional[str] = None, +) -> TeeProof: + """Verify a (decrypted) TEE gateway response. + + Args: + canonical_request: The canonical request dict (see :func:`build_inner_request`) + whose ``json.dumps(sort_keys=True)`` the gateway hashed. + response_body: The parsed inner JSON (single-shot body, or the final SSE + frame for streams), carrying ``tee_signature``, ``tee_request_hash``, + ``tee_output_hash``, ``tee_timestamp`` and ``tee_id``. + response_content: The exact text/JSON the gateway hashed as output — use + :func:`response_content_for_hash`, or the accumulated stream text. + signing_key_pem: The enclave's RSA public key from the on-chain registry + (the trust anchor; convert DER via :func:`pem_from_der`). + expected_tee_id: If given, require the response/key tee_id to match. + tee_host: Optional host, recorded on the returned proof for display. + + Returns: + A :class:`TeeProof` describing the verified provenance. + + Raises: + VerificationError: If any check fails (missing fields, tee_id mismatch, + request/output hash mismatch, or bad signature). + """ + signature_b64 = _require_str(response_body, "tee_signature") + reported_request_hash = _require_str(response_body, "tee_request_hash") + reported_output_hash = _require_str(response_body, "tee_output_hash") + timestamp = _require_int(response_body, "tee_timestamp") + reported_tee_id = _require_str(response_body, "tee_id") + + # 1. The signing key must key the tee_id the response claims, so a signature + # from enclave A cannot be replayed as enclave B's. + key_tee_id = tee_id_for_key(signing_key_pem) + if _strip0x(reported_tee_id).lower() != _strip0x(key_tee_id).lower(): + raise VerificationError(f"tee_id mismatch: response says {reported_tee_id}, signing key is {key_tee_id}") + if expected_tee_id and _strip0x(reported_tee_id).lower() != _strip0x(expected_tee_id).lower(): + raise VerificationError(f"tee_id mismatch: response says {reported_tee_id}, expected {expected_tee_id}") + + # 2. Recompute the request hash from what we actually sent. + computed_request_hash = keccak(canonical_request_bytes(canonical_request)).hex() + if computed_request_hash != _strip0x(reported_request_hash): + raise VerificationError( + "request hash mismatch: the gateway signed a different request than we sent " + f"(computed {computed_request_hash}, signed {reported_request_hash})" + ) + + # 3. Recompute the output hash from the content we're about to return. + computed_output_hash = keccak(response_content.encode("utf-8")).hex() + if computed_output_hash != _strip0x(reported_output_hash): + raise VerificationError( + "output hash mismatch: the response content does not match the signed output " + f"(computed {computed_output_hash}, signed {reported_output_hash})" + ) + + # 4. Rebuild the signed message hash and verify the RSA-PSS signature. + # msg_hash = keccak256(inputHash || outputHash || uint256(timestamp)) + input_hash = bytes.fromhex(computed_request_hash) + output_hash = bytes.fromhex(computed_output_hash) + msg_hash = keccak(input_hash + output_hash + timestamp.to_bytes(32, "big")) + + key = serialization.load_pem_public_key(signing_key_pem.encode("utf-8")) + if not isinstance(key, RSAPublicKey): + raise VerificationError("signing key is not an RSA public key") + try: + key.verify( + base64.b64decode(signature_b64), + msg_hash, + padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=32), + hashes.SHA256(), + ) + except InvalidSignature as exc: + raise VerificationError("RSA-PSS signature verification failed") from exc + + return TeeProof( + tee_id=reported_tee_id, + request_hash=computed_request_hash, + output_hash=computed_output_hash, + timestamp=timestamp, + signature=signature_b64, + signing_key_pem=signing_key_pem, + tee_host=tee_host, + ) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _first_choice(body: dict) -> dict: + choices = body.get("choices") + if isinstance(choices, list) and choices and isinstance(choices[0], dict): + return choices[0] + return {} + + +def _content_text(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + return "".join(part.get("text", "") if isinstance(part, dict) else str(part) for part in content) + return "" + + +def _strip0x(value: str) -> str: + return value[2:] if value.startswith("0x") else value + + +def _require_str(body: dict, key: str) -> str: + value = body.get(key) + if not isinstance(value, str): + raise VerificationError(f"response is missing a string '{key}' — cannot verify") + return value + + +def _require_int(body: dict, key: str) -> int: + value = body.get(key) + if isinstance(value, bool) or not isinstance(value, (int, str)): + raise VerificationError(f"response is missing an integer '{key}' — cannot verify") + try: + return int(value) + except ValueError as exc: + raise VerificationError(f"'{key}' is not an integer: {value!r}") from exc diff --git a/tests/tee_ohttp_client_test.py b/tests/tee_ohttp_client_test.py new file mode 100644 index 0000000..1e1e6e3 --- /dev/null +++ b/tests/tee_ohttp_client_test.py @@ -0,0 +1,156 @@ +"""End-to-end test for the high-level OhttpRelayClient. + +Simulates the full path — client encrypts -> (fake relay+gateway) decapsulates, +signs, seals -> client decrypts + verifies — for both single-shot and streaming, +using the real tee-gateway recipient crypto. Skips if no tee-gateway checkout. +""" + +from __future__ import annotations + +import base64 +import json +import os +import sys +from pathlib import Path + +import pytest +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from eth_hash.auto import keccak + +from opengradient.client.tee_ohttp_client import OhttpRelayClient +from opengradient.client.tee_registry import OhttpConfig, TEEEndpoint +from opengradient.client.tee_verify import build_inner_request, tee_id_for_key + + +def _load_server_ohttp(): + override = os.getenv("OG_TEE_GATEWAY") + candidates = [Path(override)] if override else [] + candidates.append(Path(__file__).resolve().parents[2] / "tee-gateway") + for root in candidates: + if (root / "tee_gateway" / "ohttp.py").exists(): + sys.path.insert(0, str(root)) + import tee_gateway.ohttp as srv + + return srv + return None + + +@pytest.fixture(scope="module") +def srv(): + s = _load_server_ohttp() + if s is None: + pytest.skip("tee-gateway checkout not found (set OG_TEE_GATEWAY)") + return s + + +def _sign_fields(priv, canonical, output_content, ts): + input_hash = keccak(json.dumps(canonical, sort_keys=True).encode()) + output_hash = keccak(output_content.encode()) + msg_hash = keccak(input_hash + output_hash + ts.to_bytes(32, "big")) + sig = priv.sign(msg_hash, padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=32), hashes.SHA256()) + return { + "tee_signature": base64.b64encode(sig).decode(), + "tee_request_hash": input_hash.hex(), + "tee_output_hash": output_hash.hex(), + "tee_timestamp": ts, + } + + +class _FakeResp: + def __init__(self, content=b"", chunks=None): + self.ok = True + self.status_code = 200 + self.content = content + self._chunks = chunks or [] + + def iter_content(self, chunk_size=8192): + yield from self._chunks + + +def _make_endpoint(srv): + hpke_priv, hpke_pub = srv.generate_keypair() + rsa_priv = rsa.generate_private_key(public_exponent=65537, key_size=2048) + der = rsa_priv.public_key().public_bytes( + encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + pem = rsa_priv.public_key().public_bytes( + encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo + ).decode() + endpoint = TEEEndpoint( + tee_id=tee_id_for_key(pem), + endpoint="https://gw.example", + tls_cert_der=b"", + payment_address="0x" + "11" * 20, + signing_public_key_der=der, + ohttp_config=OhttpConfig(1, 0x0020, 0x0001, 0x0003, hpke_pub, b"kc", 0), + ) + return endpoint, hpke_priv, rsa_priv + + +def test_single_shot_roundtrip_and_verify(srv): + endpoint, hpke_priv, rsa_priv = _make_endpoint(srv) + body = {"model": "gpt-4.1", "messages": [{"role": "user", "content": "hi"}]} + _wire, canonical = build_inner_request(body) + content = "hello from the enclave" + + def fake_post(url, data=None, headers=None, timeout=None, stream=False): + assert headers["X-TEE-ID"] == endpoint.tee_id + assert headers["Authorization"] == "Bearer t0ken" + decap = srv.decapsulate_request(hpke_priv, data) + assert json.loads(decap.plaintext.decode())["model"] == "gpt-4.1" + resp = { + "choices": [{"index": 0, "message": {"role": "assistant", "content": content}, "finish_reason": "stop"}], + **_sign_fields(rsa_priv, canonical, content, 1_700_000_000), + "tee_id": endpoint.tee_id, + } + sealed = srv.encapsulate_response(decap.response_key, decap.enc, json.dumps(resp).encode()) + return _FakeResp(content=sealed) + + class _Sess: + post = staticmethod(fake_post) + + client = OhttpRelayClient( + "https://relay/api/v1/chat/ohttp", + endpoint, + auth_headers=lambda: {"Authorization": "Bearer t0ken"}, + session=_Sess(), + ) + result = client.chat_completion(body) + assert result.content == content + assert result.proof.tee_id == endpoint.tee_id + assert result.proof.timestamp == 1_700_000_000 + + +def test_streaming_roundtrip_and_verify(srv): + endpoint, hpke_priv, rsa_priv = _make_endpoint(srv) + body = {"model": "gpt-4.1", "messages": [{"role": "user", "content": "hi"}]} + _wire, canonical = build_inner_request(body) + full = "Hello world" + + def fake_post(url, data=None, headers=None, timeout=None, stream=False): + assert stream is True + assert headers["X-OHTTP-Stream"] == "true" + decap = srv.decapsulate_request(hpke_priv, data) + encr = srv.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) + wire = encr.header() + wire += encr.encrypt_chunk(b'data: {"choices":[{"delta":{"content":"Hello "},"index":0}]}\n\n', is_final=False) + wire += encr.encrypt_chunk(b'data: {"choices":[{"delta":{"content":"world"},"index":0}]}\n\n', is_final=False) + final = { + "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}], + **_sign_fields(rsa_priv, canonical, full, 1_700_000_001), + "tee_id": endpoint.tee_id, + } + wire += encr.encrypt_chunk(f"data: {json.dumps(final)}\n\n".encode(), is_final=True) + # Deliver as a couple of network chunks to exercise buffering. + mid = len(wire) // 2 + return _FakeResp(chunks=[wire[:mid], wire[mid:]]) + + class _Sess: + post = staticmethod(fake_post) + + client = OhttpRelayClient("https://relay/api/v1/chat/ohttp", endpoint, session=_Sess()) + result = client.stream_chat_completion(body) + assert result.content == full + assert result.proof.timestamp == 1_700_000_001 + assert result.stream_frames is not None and len(result.stream_frames) == 3 diff --git a/tests/tee_ohttp_test.py b/tests/tee_ohttp_test.py new file mode 100644 index 0000000..87ec923 --- /dev/null +++ b/tests/tee_ohttp_test.py @@ -0,0 +1,91 @@ +"""Wire-compatibility round-trips for the client OHTTP encapsulation. + +When a tee-gateway checkout is available (``OG_TEE_GATEWAY`` env var, or a sibling +``../tee-gateway``), these round-trip our client crypto against the *actual* +server recipient code, guaranteeing the two stay byte-compatible. The gateway's +``tee_gateway/ohttp.py`` only needs ``cryptography`` + ``pyhpke``, so importing it +standalone is cheap. The tests skip cleanly when no checkout is found. +""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path + +import pytest + +from opengradient.client import tee_ohttp as cli + + +def _load_server_ohttp(): + override = os.getenv("OG_TEE_GATEWAY") + candidates = [Path(override)] if override else [] + candidates.append(Path(__file__).resolve().parents[2] / "tee-gateway") + for root in candidates: + if (root / "tee_gateway" / "ohttp.py").exists(): + sys.path.insert(0, str(root)) + import tee_gateway.ohttp as srv + + return srv + return None + + +@pytest.fixture(scope="module") +def server_ohttp(): + srv = _load_server_ohttp() + if srv is None: + pytest.skip("tee-gateway checkout not found (set OG_TEE_GATEWAY)") + return srv + + +def test_request_and_single_shot_response(server_ohttp): + priv, pub_raw = server_ohttp.generate_keypair() + plaintext = b'{"model":"gpt-4.1","messages":[{"role":"user","content":"hi"}]}' + + enc_req = cli.encapsulate_request(pub_raw, plaintext) + decap = server_ohttp.decapsulate_request(priv, enc_req.wire) + assert decap.plaintext == plaintext + assert decap.enc == enc_req.enc + + resp_pt = b'{"choices":[{"message":{"content":"hello"}}]}' + sealed = server_ohttp.encapsulate_response(decap.response_key, decap.enc, resp_pt) + assert cli.decrypt_response(enc_req.response_secret, enc_req.enc, sealed) == resp_pt + + +def test_chunked_response_whole_and_incremental(server_ohttp): + priv, pub_raw = server_ohttp.generate_keypair() + enc_req = cli.encapsulate_request(pub_raw, b"{}") + decap = server_ohttp.decapsulate_request(priv, enc_req.wire) + + encr = server_ohttp.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) + wire = encr.header() + frames = [b"data: a\n\n", b"data: b\n\n", b"data: [DONE]\n\n"] + wire += encr.encrypt_chunk(frames[0], is_final=False) + wire += encr.encrypt_chunk(frames[1], is_final=False) + wire += encr.encrypt_chunk(frames[2], is_final=True) + + dec = cli.ChunkedResponseDecrypter(enc_req.chunked_response_secret, enc_req.enc) + assert dec.push(wire, done=True) == frames + + dec2 = cli.ChunkedResponseDecrypter(enc_req.chunked_response_secret, enc_req.enc) + out: list[bytes] = [] + for i, b in enumerate(wire): + out += dec2.push(bytes([b]), done=(i == len(wire) - 1)) + assert out == frames + + +def test_truncated_chunked_stream_is_rejected(server_ohttp): + priv, pub_raw = server_ohttp.generate_keypair() + enc_req = cli.encapsulate_request(pub_raw, b"{}") + decap = server_ohttp.decapsulate_request(priv, enc_req.wire) + encr = server_ohttp.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) + wire = encr.header() + encr.encrypt_chunk(b"data: a\n\n", is_final=False) + dec = cli.ChunkedResponseDecrypter(enc_req.chunked_response_secret, enc_req.enc) + with pytest.raises(ValueError): + dec.push(wire, done=True) + + +def test_rejects_wrong_size_public_key(): + with pytest.raises(ValueError): + cli.encapsulate_request(b"too short", b"{}") diff --git a/tests/tee_verify_test.py b/tests/tee_verify_test.py new file mode 100644 index 0000000..d467f06 --- /dev/null +++ b/tests/tee_verify_test.py @@ -0,0 +1,158 @@ +"""Verify the TEE response verifier against an independently-built gateway signature. + +We reconstruct exactly what the gateway signs — RSA-PSS(salt=32, SHA256) over +``keccak256(inputHash || outputHash || uint256(ts))`` — and confirm +``verify_response`` accepts a good signature and rejects every tampered variant. +""" + +from __future__ import annotations + +import base64 +import json + +import pytest +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from eth_hash.auto import keccak + +from opengradient.client import tee_verify as verify +from opengradient.client.tee_verify import build_inner_request + + +def _make_key(): + priv = rsa.generate_private_key(public_exponent=65537, key_size=2048) + pem = ( + priv.public_key() + .public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + .decode() + ) + return priv, pem + + +def _sign(priv, canonical, output_content, timestamp): + request_bytes = json.dumps(canonical, sort_keys=True).encode("utf-8") + input_hash = keccak(request_bytes) + output_hash = keccak(output_content.encode("utf-8")) + msg_hash = keccak(input_hash + output_hash + timestamp.to_bytes(32, "big")) + sig = priv.sign( + msg_hash, + padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=32), + hashes.SHA256(), + ) + return { + "tee_signature": base64.b64encode(sig).decode(), + "tee_request_hash": input_hash.hex(), + "tee_output_hash": output_hash.hex(), + "tee_timestamp": timestamp, + } + + +def _good_case(): + priv, pem = _make_key() + tee_id = verify.tee_id_for_key(pem) + _wire, canonical = build_inner_request( + {"model": "gpt-4.1", "messages": [{"role": "user", "content": "Hello!"}]} + ) + content = "Hi there!" + response = { + "choices": [{"index": 0, "message": {"role": "assistant", "content": content}, "finish_reason": "stop"}], + **_sign(priv, canonical, content, 1_700_000_000), + "tee_id": tee_id, + } + return pem, tee_id, canonical, content, response + + +def test_valid_signature_verifies(): + pem, tee_id, canonical, content, response = _good_case() + proof = verify.verify_response( + canonical_request=canonical, + response_body=response, + response_content=content, + signing_key_pem=pem, + expected_tee_id=tee_id, + ) + assert proof.tee_id == tee_id + assert proof.timestamp == 1_700_000_000 + + +def test_tampered_content_is_rejected(): + pem, _tee_id, canonical, _content, response = _good_case() + with pytest.raises(verify.VerificationError, match="output hash"): + verify.verify_response( + canonical_request=canonical, + response_body=response, + response_content="Hi there! (tampered)", + signing_key_pem=pem, + ) + + +def test_tampered_request_is_rejected(): + pem, _tee_id, _canonical, content, response = _good_case() + other = build_inner_request({"model": "gpt-4.1", "messages": [{"role": "user", "content": "different"}]})[1] + with pytest.raises(verify.VerificationError, match="request hash"): + verify.verify_response( + canonical_request=other, + response_body=response, + response_content=content, + signing_key_pem=pem, + ) + + +def test_wrong_signing_key_is_rejected(): + _pem, _tee_id, canonical, content, response = _good_case() + _other_priv, other_pem = _make_key() + with pytest.raises(verify.VerificationError, match="tee_id mismatch"): + verify.verify_response( + canonical_request=canonical, + response_body=response, + response_content=content, + signing_key_pem=other_pem, + ) + + +def test_tool_call_output_hashing(): + priv, pem = _make_key() + tee_id = verify.tee_id_for_key(pem) + _wire, canonical = build_inner_request({"model": "gpt-4.1", "messages": [{"role": "user", "content": "weather?"}]}) + tool_calls = [{"id": "call_1", "type": "function", "function": {"name": "get", "arguments": "{}"}}] + output_content = json.dumps(tool_calls, sort_keys=True) + response = { + "choices": [{"index": 0, "message": {"role": "assistant", "tool_calls": tool_calls}, "finish_reason": "tool_calls"}], + **_sign(priv, canonical, output_content, 1_700_000_001), + "tee_id": tee_id, + } + assert verify.response_content_for_hash(response) == output_content + proof = verify.verify_response( + canonical_request=canonical, + response_body=response, + response_content=verify.response_content_for_hash(response), + signing_key_pem=pem, + ) + assert proof.timestamp == 1_700_000_001 + + +def test_build_inner_request_strips_attachment_bytes_from_hash_only(): + body = { + "model": "gpt-4.1", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "describe"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}}, + ], + } + ], + "temperature": 0.7, + "max_tokens": 256, + } + wire, canonical = build_inner_request(body) + # Wire keeps the bytes so the model sees the image... + assert wire["messages"][0]["content"][1]["image_url"]["url"].startswith("data:image/png") + # ...but the canonical (hashed) form commits only to type, not bytes. + assert canonical["messages"][0]["content"][1] == {"type": "image_url"} + assert canonical["temperature"] == 0.7 + assert canonical["max_tokens"] == 256 diff --git a/uv.lock b/uv.lock index 97d6f54..a3a339c 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.11" resolution-markers = [ "python_full_version >= '3.14'", @@ -1909,10 +1909,11 @@ wheels = [ [[package]] name = "opengradient" -version = "1.0.7" +version = "1.0.8" source = { editable = "." } dependencies = [ { name = "click" }, + { name = "cryptography" }, { name = "eth-account" }, { name = "firebase-rest-api" }, { name = "langchain" }, @@ -1920,6 +1921,7 @@ dependencies = [ { name = "og-x402", extra = ["extensions"] }, { name = "openai" }, { name = "pydantic" }, + { name = "pyhpke" }, { name = "requests" }, { name = "requests-toolbelt" }, { name = "web3" }, @@ -1944,6 +1946,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "click", specifier = ">=8.1.7" }, + { name = "cryptography", specifier = ">=43.0.0" }, { name = "eth-account", specifier = ">=0.13.4" }, { name = "firebase-rest-api", specifier = ">=1.11.0" }, { name = "langchain", specifier = ">=0.3.7" }, @@ -1955,6 +1958,7 @@ requires-dist = [ { name = "openai", specifier = ">=1.58.1" }, { name = "pdoc3", marker = "extra == 'dev'", specifier = "==0.10.0" }, { name = "pydantic", specifier = ">=2.9.2" }, + { name = "pyhpke", specifier = ">=0.6.0" }, { name = "pytest", marker = "extra == 'dev'" }, { name = "pytest-asyncio", marker = "extra == 'dev'" }, { name = "requests", specifier = ">=2.32.3" }, @@ -2452,6 +2456,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyhpke" +version = "0.6.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/27/37/1acb2cee5afd3dcf45b425b0d984a9cba8917fd935106ef278b42062ecfa/pyhpke-0.6.4.tar.gz", hash = "sha256:1402c6c41a0605941d2d2a589774d346c0e7a0dc7f745e84c6f0a06c2fd335c9", size = 1638147, upload-time = "2025-12-21T10:38:07.556Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/f6/ff7df9e21b38ec1c827efd90c28b3bc76eddbfdf5a44aaf2fadb59a17cb9/pyhpke-0.6.4-py3-none-any.whl", hash = "sha256:abd0b2fec1424858399ffbed0d236fb7e9740dece9907f59ca40bd567d7fef78", size = 23792, upload-time = "2025-12-21T10:38:06.172Z" }, +] + [[package]] name = "pytest" version = "9.0.2" From c4447362ae749bc955b9c6772c449acd40e9bfa0 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 12 Jun 2026 16:07:08 +0000 Subject: [PATCH 2/3] Bump version to 1.1.0 for verified-inference client utilities Minor release adding the OHTTP client + signature-verification utilities (OhttpRelayClient, verify_response, build_inner_request, TEEEndpoint.pcr_hash). Lets downstream packages (e.g. opengradient-local) pin opengradient>=1.1.0. https://claude.ai/code/session_01PdYbDC47zuBGiZex7ZHMSs --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index de454d2..951ef7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "opengradient" -version = "1.0.8" +version = "1.1.0" description = "Python SDK for OpenGradient decentralized model management & inference services" authors = [{name = "OpenGradient", email = "adam@vannalabs.ai"}] readme = "README.md" diff --git a/uv.lock b/uv.lock index a3a339c..afabd8b 100644 --- a/uv.lock +++ b/uv.lock @@ -1909,7 +1909,7 @@ wheels = [ [[package]] name = "opengradient" -version = "1.0.8" +version = "1.1.0" source = { editable = "." } dependencies = [ { name = "click" }, From 29de49d901ac46a147ed75cfc8aaa75b4e9e7be5 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 12 Jun 2026 16:16:28 +0000 Subject: [PATCH 3/3] Address PR review: key_id rotation, stricter request validation, stream tool-calls, CI-runnable crypto tests - tee_ohttp.encapsulate_request: thread key_id (and validate kem/kdf/aead) from the registry OhttpConfig so a TEE that rotated its key_id still decapsulates. - tee_verify.build_inner_request: reject non-list 'tools' and non-dict messages with UnsupportedRequestError (was AttributeError / silent dict->keys coercion); preserve tool_choice on the wire while keeping it out of the signed hash. - OhttpRelayClient: wrap chunked-decrypt ValueError as VerificationError; pass the config's key/alg ids; aggregate streamed delta.tool_calls so honest tool-call streams verify against the gateway's signed json.dumps(tool_calls). - Tests: add a self-contained OHTTP recipient so encapsulation/decryption + end-to-end verification run in CI without a tee-gateway checkout (real gateway still cross-checked when present, via a sys.path-restoring fixture). https://claude.ai/code/session_01PdYbDC47zuBGiZex7ZHMSs --- src/opengradient/client/tee_ohttp.py | 34 ++++- src/opengradient/client/tee_ohttp_client.py | 99 +++++++++++---- src/opengradient/client/tee_verify.py | 25 +++- tests/_ohttp_recipient.py | 115 +++++++++++++++++ tests/conftest.py | 47 +++++++ tests/tee_ohttp_client_test.py | 131 ++++++++++++-------- tests/tee_ohttp_test.py | 85 +++++++------ tests/tee_verify_test.py | 25 ++++ 8 files changed, 437 insertions(+), 124 deletions(-) create mode 100644 tests/_ohttp_recipient.py create mode 100644 tests/conftest.py diff --git a/src/opengradient/client/tee_ohttp.py b/src/opengradient/client/tee_ohttp.py index d39f969..3e4f67c 100644 --- a/src/opengradient/client/tee_ohttp.py +++ b/src/opengradient/client/tee_ohttp.py @@ -53,8 +53,8 @@ ) -def _header_bytes() -> bytes: - return bytes([KEY_CONFIG_ID]) + struct.pack(">HHH", KEM_ID_X25519, KDF_ID_HKDF_SHA256, AEAD_ID_CHACHA20_POLY1305) +def _header_bytes(key_id: int = KEY_CONFIG_ID) -> bytes: + return bytes([key_id]) + struct.pack(">HHH", KEM_ID_X25519, KDF_ID_HKDF_SHA256, AEAD_ID_CHACHA20_POLY1305) @dataclass @@ -74,29 +74,51 @@ class EncapsulatedRequest: chunked_response_secret: bytes -def encapsulate_request(public_key_raw: bytes, plaintext: bytes) -> EncapsulatedRequest: +def encapsulate_request( + public_key_raw: bytes, + plaintext: bytes, + *, + key_id: int = KEY_CONFIG_ID, + kem_id: int = KEM_ID_X25519, + kdf_id: int = KDF_ID_HKDF_SHA256, + aead_id: int = AEAD_ID_CHACHA20_POLY1305, +) -> EncapsulatedRequest: """HPKE-seal ``plaintext`` to a TEE's raw X25519 public key. Args: public_key_raw: The 32-byte raw X25519 public key from the TEE's OHTTP config (``OhttpConfig.public_key``). plaintext: The inner request body (typically a UTF-8 JSON chat request). + key_id: The OHTTP key-config id from the TEE's ``OhttpConfig.key_id``. + Threaded into the request header so a TEE that rotated to a new + key_id (while keeping this suite) can still decapsulate. Defaults to + the canonical ``0x01``. + kem_id, kdf_id, aead_id: The HPKE algorithm ids from the TEE's config. + This client implements one fixed suite; mismatching ids are rejected + rather than silently producing an undecryptable request. Returns: An `EncapsulatedRequest` ready to send to a relay. Raises: - ValueError: If ``public_key_raw`` is not 32 bytes. + ValueError: If ``public_key_raw`` is not 32 bytes, or the algorithm ids + don't match this client's supported suite. """ if len(public_key_raw) != 32: raise ValueError("X25519 public key must be 32 bytes") + if (kem_id, kdf_id, aead_id) != (KEM_ID_X25519, KDF_ID_HKDF_SHA256, AEAD_ID_CHACHA20_POLY1305): + raise ValueError( + "unsupported HPKE suite " + f"(kem={kem_id:#06x}, kdf={kdf_id:#06x}, aead={aead_id:#06x}); " + "this client only implements DHKEM-X25519 / HKDF-SHA256 / ChaCha20-Poly1305" + ) pkr = _SUITE.kem.deserialize_public_key(public_key_raw) - info = _LABEL_REQUEST + b"\x00" + _header_bytes() + info = _LABEL_REQUEST + b"\x00" + _header_bytes(key_id) enc, sender = _SUITE.create_sender_context(pkr, info=info) ciphertext = sender.seal(plaintext, aad=b"") - wire = _header_bytes() + bytes(enc) + ciphertext + wire = _header_bytes(key_id) + bytes(enc) + ciphertext export_len = max(_NN, _NK) return EncapsulatedRequest( diff --git a/src/opengradient/client/tee_ohttp_client.py b/src/opengradient/client/tee_ohttp_client.py index e5593eb..02c500f 100644 --- a/src/opengradient/client/tee_ohttp_client.py +++ b/src/opengradient/client/tee_ohttp_client.py @@ -110,6 +110,14 @@ def __init__( self._relay_url = relay_url self._tee = tee self._ohttp_public_key = tee.ohttp_config.public_key + # Honor the registry's advertised key/algorithm ids (key rotation) rather + # than assuming the canonical defaults. + self._enc_ids = { + "key_id": tee.ohttp_config.key_id, + "kem_id": tee.ohttp_config.kem_id, + "kdf_id": tee.ohttp_config.kdf_id, + "aead_id": tee.ohttp_config.aead_id, + } self._auth_headers = auth_headers self._session = session or requests.Session() self._timeout = timeout @@ -130,7 +138,7 @@ def chat_completion(self, body: dict) -> VerifiedChatResponse: opengradient.client.tee_verify.UnsupportedRequestError: If the body is invalid. """ wire, canonical = build_inner_request(body) - enc = encapsulate_request(self._ohttp_public_key, json.dumps(wire).encode("utf-8")) + enc = encapsulate_request(self._ohttp_public_key, json.dumps(wire).encode("utf-8"), **self._enc_ids) resp = self._session.post( self._relay_url, @@ -180,7 +188,7 @@ def stream_chat_completion(self, body: dict) -> VerifiedChatResponse: """ wire, canonical = build_inner_request(body) wire = {**wire, "stream": True} - enc = encapsulate_request(self._ohttp_public_key, json.dumps(wire).encode("utf-8")) + enc = encapsulate_request(self._ohttp_public_key, json.dumps(wire).encode("utf-8"), **self._enc_ids) resp = self._session.post( self._relay_url, @@ -195,32 +203,45 @@ def stream_chat_completion(self, body: dict) -> VerifiedChatResponse: decrypter = ChunkedResponseDecrypter(enc.chunked_response_secret, enc.enc) frames: list[str] = [] full_content = "" + tool_calls: dict[int, dict] = {} final_frame: Optional[dict] = None chunks = resp.iter_content(chunk_size=8192) - for raw, is_last in _with_last(chunks): - for plaintext in decrypter.push(raw, done=is_last): - text = plaintext.decode("utf-8", errors="replace") - frames.append(text) - delta, final = _parse_sse_frame(text) - full_content += delta - if final is not None: - final_frame = final + try: + for raw, is_last in _with_last(chunks): + # A malformed/truncated encrypted stream is an integrity failure; + # surface it as VerificationError, not a raw ValueError. + for plaintext in decrypter.push(raw, done=is_last): + text = plaintext.decode("utf-8", errors="replace") + frames.append(text) + for parsed in _iter_sse_objects(text): + full_content += _delta_content(parsed) + _accumulate_tool_calls(tool_calls, parsed) + if isinstance(parsed.get("tee_signature"), str) or isinstance(parsed.get("tee_output_hash"), str): + final_frame = parsed + except ValueError as exc: + raise VerificationError(f"malformed TEE stream: {exc}") from exc if final_frame is None: raise VerificationError("TEE stream missing a signed final frame") + # The gateway signs the assistant text, except for tool-call responses + # where it signs json.dumps(tool_calls, sort_keys=True) of the buffered + # calls — mirror that so honest tool-call streams verify. + if _finish_reason(final_frame) == "tool_calls" and tool_calls: + response_content = json.dumps([tool_calls[i] for i in sorted(tool_calls)], sort_keys=True) + else: + response_content = full_content + proof = verify_response( canonical_request=canonical, response_body=final_frame, - response_content=full_content, + response_content=response_content, signing_key_pem=self._signing_key_pem, expected_tee_id=self._tee.tee_id, tee_host=self._tee.endpoint, ) - return VerifiedChatResponse( - body=final_frame, content=full_content, proof=proof, stream_frames=frames - ) + return VerifiedChatResponse(body=final_frame, content=response_content, proof=proof, stream_frames=frames) def _headers(self, *, stream: bool) -> dict: headers = { @@ -249,10 +270,8 @@ def _normalize_inner(decoded) -> tuple[int, dict]: raise VerificationError("malformed inner response") -def _parse_sse_frame(text: str) -> tuple[str, Optional[dict]]: - """Parse one decrypted SSE frame; return ``(delta_text, final_frame_or_None)``.""" - delta = "" - final: Optional[dict] = None +def _iter_sse_objects(text: str): + """Yield the parsed JSON objects from a decrypted SSE frame's ``data:`` lines.""" for line in text.splitlines(): line = line.strip() if not line.startswith("data:"): @@ -268,10 +287,46 @@ def _parse_sse_frame(text: str) -> tuple[str, Optional[dict]]: continue if isinstance(parsed.get("error"), str): raise RelayError(502, parsed["error"]) - delta += _delta_content(parsed) - if isinstance(parsed.get("tee_signature"), str) or isinstance(parsed.get("tee_output_hash"), str): - final = parsed - return delta, final + yield parsed + + +def _accumulate_tool_calls(buffered: dict[int, dict], frame: dict) -> None: + """Fold a streamed ``delta.tool_calls`` fragment into ``buffered`` (keyed by index). + + Mirrors the gateway's streaming tool-call buffer so the reconstructed list + matches the signed output: ids/names are set when present, argument fragments + are concatenated in arrival order. + """ + choices = frame.get("choices") + if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict): + return + delta = choices[0].get("delta") + if not isinstance(delta, dict): + return + fragments = delta.get("tool_calls") + if not isinstance(fragments, list): + return + for frag in fragments: + if not isinstance(frag, dict): + continue + idx = frag.get("index", 0) + slot = buffered.setdefault(idx, {"id": "", "type": "function", "function": {"name": "", "arguments": ""}}) + if frag.get("id"): + slot["id"] = frag["id"] + if frag.get("type"): + slot["type"] = frag["type"] + fn = frag.get("function") or {} + if fn.get("name"): + slot["function"]["name"] = fn["name"] + if fn.get("arguments"): + slot["function"]["arguments"] += fn["arguments"] + + +def _finish_reason(frame: dict) -> Optional[str]: + choices = frame.get("choices") + if isinstance(choices, list) and choices and isinstance(choices[0], dict): + return choices[0].get("finish_reason") + return None def _delta_content(frame: dict) -> str: diff --git a/src/opengradient/client/tee_verify.py b/src/opengradient/client/tee_verify.py index da38c17..6288aee 100644 --- a/src/opengradient/client/tee_verify.py +++ b/src/opengradient/client/tee_verify.py @@ -190,6 +190,9 @@ def build_inner_request(body: dict) -> tuple[dict, dict]: messages = body.get("messages") if not isinstance(messages, list) or not messages: raise UnsupportedRequestError("request is missing 'messages'") + for m in messages: + if not isinstance(m, dict): + raise UnsupportedRequestError(f"each message must be a JSON object, got {type(m).__name__}") canonical_messages = [_canonical_message(m) for m in messages] @@ -205,20 +208,36 @@ def build_inner_request(body: dict) -> tuple[dict, dict]: canonical["stop"] = body["stop"] if body.get("tools"): tools = body["tools"] - canonical["tools"] = tools if isinstance(tools, list) else list(tools) + # The gateway hashes the tools list as-is; coercing a non-list (e.g. a + # dict) would change the wire request AND produce a hash that can never + # match. Reject it rather than guessing. + if not isinstance(tools, list): + raise UnsupportedRequestError(f"'tools' must be a list, got {type(tools).__name__}") + canonical["tools"] = tools if body.get("response_format"): canonical["response_format"] = body["response_format"] if body.get("web_search"): canonical["web_search"] = True + # The wire request is the canonical request with full (un-stripped) user + # content restored, plus any generation-affecting fields the gateway honors + # but does NOT fold into the signed hash (e.g. tool_choice) — dropping those + # would silently ignore caller intent. wire = copy.deepcopy(canonical) wire["messages"] = [ - ({**cm, "content": orig.get("content")} if cm.get("role") == "user" else cm) - for cm, orig in zip(wire["messages"], messages) + ({**cm, "content": orig.get("content")} if cm.get("role") == "user" else cm) for cm, orig in zip(wire["messages"], messages) ] + for field in _WIRE_ONLY_PASSTHROUGH: + if body.get(field) is not None: + wire[field] = body[field] return wire, canonical +# Fields forwarded to the gateway on the wire but intentionally excluded from the +# signed request hash (the gateway does not commit to them either). +_WIRE_ONLY_PASSTHROUGH = ("tool_choice",) + + def canonical_request_bytes(canonical_request: dict) -> bytes: """Serialize a canonical request exactly as the gateway hashes it: ``json.dumps(sort_keys=True)``.""" return json.dumps(canonical_request, sort_keys=True).encode("utf-8") diff --git a/tests/_ohttp_recipient.py b/tests/_ohttp_recipient.py new file mode 100644 index 0000000..6898319 --- /dev/null +++ b/tests/_ohttp_recipient.py @@ -0,0 +1,115 @@ +"""Self-contained OHTTP *recipient* for tests. + +A minimal, dependency-light port of the tee-gateway's recipient side +(``tee_gateway/ohttp.py``) using the same primitives this SDK already depends on +(pyhpke + cryptography). It lets the OHTTP wire-compatibility and end-to-end +verification tests run in CI **without** requiring an external tee-gateway +checkout. The real gateway is still cross-checked when available (see the +``real_tee_gateway`` fixture in ``conftest.py``). + +If this ever diverges from the gateway, the cross-check test fails — so it can't +silently rot. +""" + +from __future__ import annotations + +import os +import struct +from dataclasses import dataclass + +from cryptography.hazmat.primitives import hashes, hmac +from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 +from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand +from pyhpke import AEADId, CipherSuite, KDFId, KEMId + +KEY_CONFIG_ID = 0x01 +KEM_ID_X25519 = 0x0020 +KDF_ID_HKDF_SHA256 = 0x0001 +AEAD_ID_CHACHA20_POLY1305 = 0x0003 + +_NK = 32 +_NN = 12 + +_LABEL_REQUEST = b"message/bhttp request" +_LABEL_RESPONSE = b"message/bhttp response" +_LABEL_CHUNKED_RESPONSE = b"message/bhttp chunked response" + +_SUITE = CipherSuite.new(KEMId.DHKEM_X25519_HKDF_SHA256, KDFId.HKDF_SHA256, AEADId.CHACHA20_POLY1305) + + +def _header_bytes(key_id: int = KEY_CONFIG_ID) -> bytes: + return bytes([key_id]) + struct.pack(">HHH", KEM_ID_X25519, KDF_ID_HKDF_SHA256, AEAD_ID_CHACHA20_POLY1305) + + +def encode_varint(value: int) -> bytes: + if value < (1 << 6): + return bytes([value]) + if value < (1 << 14): + return bytes([0x40 | (value >> 8), value & 0xFF]) + if value < (1 << 30): + return struct.pack(">I", 0x80000000 | value) + return struct.pack(">Q", 0xC000000000000000 | value) + + +def generate_keypair(): + pair = _SUITE.kem.derive_key_pair(os.urandom(32)) + return pair.private_key, pair.public_key.to_public_bytes() + + +@dataclass +class DecapsulatedRequest: + plaintext: bytes + response_key: bytes + response_key_chunked: bytes + enc: bytes + + +def decapsulate_request(private_key, encapsulated_request: bytes) -> DecapsulatedRequest: + key_id = encapsulated_request[0] + enc = encapsulated_request[7 : 7 + 32] + aead_ct = encapsulated_request[7 + 32 :] + info = _LABEL_REQUEST + b"\x00" + _header_bytes(key_id) + recipient = _SUITE.create_recipient_context(enc, private_key, info=info) + plaintext = recipient.open(aead_ct, aad=b"") + export_len = max(_NN, _NK) + return DecapsulatedRequest( + plaintext=plaintext, + response_key=recipient.export(_LABEL_RESPONSE, export_len), + response_key_chunked=recipient.export(_LABEL_CHUNKED_RESPONSE, export_len), + enc=enc, + ) + + +def _derive_response_keys(response_secret: bytes, enc: bytes, response_nonce: bytes) -> tuple[bytes, bytes]: + h = hmac.HMAC(enc + response_nonce, hashes.SHA256()) + h.update(response_secret) + prk = h.finalize() + aead_key = HKDFExpand(algorithm=hashes.SHA256(), length=_NK, info=b"key").derive(prk) + aead_nonce = HKDFExpand(algorithm=hashes.SHA256(), length=_NN, info=b"nonce").derive(prk) + return aead_key, aead_nonce + + +def encapsulate_response(response_secret: bytes, enc: bytes, plaintext: bytes) -> bytes: + response_nonce = os.urandom(max(_NN, _NK)) + aead_key, aead_nonce = _derive_response_keys(response_secret, enc, response_nonce) + return response_nonce + ChaCha20Poly1305(aead_key).encrypt(aead_nonce, plaintext, b"") + + +class ChunkedResponseEncrypter: + def __init__(self, response_secret: bytes, enc: bytes): + self._response_nonce = os.urandom(max(_NN, _NK)) + self._aead_key, self._aead_nonce = _derive_response_keys(response_secret, enc, self._response_nonce) + self._aead = ChaCha20Poly1305(self._aead_key) + self._counter = 0 + + def header(self) -> bytes: + return self._response_nonce + + def encrypt_chunk(self, plaintext: bytes, is_final: bool) -> bytes: + ctr = self._counter.to_bytes(_NN, "big") + chunk_nonce = bytes(a ^ b for a, b in zip(self._aead_nonce, ctr)) + aad = b"final" if is_final else b"" + sealed = self._aead.encrypt(chunk_nonce, plaintext, aad) + self._counter += 1 + length_prefix = encode_varint(0) if is_final else encode_varint(len(sealed)) + return length_prefix + sealed diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..affdd10 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,47 @@ +"""Shared fixtures for the OHTTP / TEE-verification tests.""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path + +import _ohttp_recipient as _recipient # tests/ is on sys.path under pytest +import pytest + + +@pytest.fixture(scope="session") +def recipient(): + """A self-contained OHTTP recipient so crypto is exercised in CI without an + external tee-gateway checkout.""" + return _recipient + + +@pytest.fixture +def real_tee_gateway(): + """The real ``tee_gateway.ohttp`` module, for cross-checking wire compatibility. + + Skips when no tee-gateway checkout is present (``OG_TEE_GATEWAY`` or a sibling + ``../tee-gateway``). Restores ``sys.path`` afterwards so it can't perturb + import resolution in later tests. + """ + override = os.getenv("OG_TEE_GATEWAY") + candidates = [Path(override)] if override else [] + candidates.append(Path(__file__).resolve().parents[2] / "tee-gateway") + root = next((p for p in candidates if (p / "tee_gateway" / "ohttp.py").exists()), None) + if root is None: + pytest.skip("tee-gateway checkout not found (set OG_TEE_GATEWAY)") + + inserted = str(root) + sys.path.insert(0, inserted) + try: + import tee_gateway.ohttp as srv + + yield srv + finally: + try: + sys.path.remove(inserted) + except ValueError: + pass + for name in [n for n in sys.modules if n == "tee_gateway" or n.startswith("tee_gateway.")]: + del sys.modules[name] diff --git a/tests/tee_ohttp_client_test.py b/tests/tee_ohttp_client_test.py index 1e1e6e3..5cc90cb 100644 --- a/tests/tee_ohttp_client_test.py +++ b/tests/tee_ohttp_client_test.py @@ -1,19 +1,16 @@ -"""End-to-end test for the high-level OhttpRelayClient. +"""End-to-end tests for the high-level OhttpRelayClient. -Simulates the full path — client encrypts -> (fake relay+gateway) decapsulates, -signs, seals -> client decrypts + verifies — for both single-shot and streaming, -using the real tee-gateway recipient crypto. Skips if no tee-gateway checkout. +Simulates the full path — client encrypts -> (fake relay + recipient) decapsulates, +signs, seals -> client decrypts + verifies — for single-shot, streaming text, and +streaming tool calls, using the self-contained ``recipient`` fixture so it runs in +CI without an external checkout. """ from __future__ import annotations import base64 import json -import os -import sys -from pathlib import Path -import pytest from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding, rsa from eth_hash.auto import keccak @@ -23,27 +20,6 @@ from opengradient.client.tee_verify import build_inner_request, tee_id_for_key -def _load_server_ohttp(): - override = os.getenv("OG_TEE_GATEWAY") - candidates = [Path(override)] if override else [] - candidates.append(Path(__file__).resolve().parents[2] / "tee-gateway") - for root in candidates: - if (root / "tee_gateway" / "ohttp.py").exists(): - sys.path.insert(0, str(root)) - import tee_gateway.ohttp as srv - - return srv - return None - - -@pytest.fixture(scope="module") -def srv(): - s = _load_server_ohttp() - if s is None: - pytest.skip("tee-gateway checkout not found (set OG_TEE_GATEWAY)") - return s - - def _sign_fields(priv, canonical, output_content, ts): input_hash = keccak(json.dumps(canonical, sort_keys=True).encode()) output_hash = keccak(output_content.encode()) @@ -68,8 +44,8 @@ def iter_content(self, chunk_size=8192): yield from self._chunks -def _make_endpoint(srv): - hpke_priv, hpke_pub = srv.generate_keypair() +def _make_endpoint(recipient): + hpke_priv, hpke_pub = recipient.generate_keypair() rsa_priv = rsa.generate_private_key(public_exponent=65537, key_size=2048) der = rsa_priv.public_key().public_bytes( encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo @@ -88,8 +64,12 @@ def _make_endpoint(srv): return endpoint, hpke_priv, rsa_priv -def test_single_shot_roundtrip_and_verify(srv): - endpoint, hpke_priv, rsa_priv = _make_endpoint(srv) +def _session_with(fake_post): + return type("S", (), {"post": staticmethod(fake_post)})() + + +def test_single_shot_roundtrip_and_verify(recipient): + endpoint, hpke_priv, rsa_priv = _make_endpoint(recipient) body = {"model": "gpt-4.1", "messages": [{"role": "user", "content": "hi"}]} _wire, canonical = build_inner_request(body) content = "hello from the enclave" @@ -97,24 +77,21 @@ def test_single_shot_roundtrip_and_verify(srv): def fake_post(url, data=None, headers=None, timeout=None, stream=False): assert headers["X-TEE-ID"] == endpoint.tee_id assert headers["Authorization"] == "Bearer t0ken" - decap = srv.decapsulate_request(hpke_priv, data) + decap = recipient.decapsulate_request(hpke_priv, data) assert json.loads(decap.plaintext.decode())["model"] == "gpt-4.1" resp = { "choices": [{"index": 0, "message": {"role": "assistant", "content": content}, "finish_reason": "stop"}], **_sign_fields(rsa_priv, canonical, content, 1_700_000_000), "tee_id": endpoint.tee_id, } - sealed = srv.encapsulate_response(decap.response_key, decap.enc, json.dumps(resp).encode()) + sealed = recipient.encapsulate_response(decap.response_key, decap.enc, json.dumps(resp).encode()) return _FakeResp(content=sealed) - class _Sess: - post = staticmethod(fake_post) - client = OhttpRelayClient( "https://relay/api/v1/chat/ohttp", endpoint, auth_headers=lambda: {"Authorization": "Bearer t0ken"}, - session=_Sess(), + session=_session_with(fake_post), ) result = client.chat_completion(body) assert result.content == content @@ -122,17 +99,16 @@ class _Sess: assert result.proof.timestamp == 1_700_000_000 -def test_streaming_roundtrip_and_verify(srv): - endpoint, hpke_priv, rsa_priv = _make_endpoint(srv) +def test_streaming_roundtrip_and_verify(recipient): + endpoint, hpke_priv, rsa_priv = _make_endpoint(recipient) body = {"model": "gpt-4.1", "messages": [{"role": "user", "content": "hi"}]} _wire, canonical = build_inner_request(body) full = "Hello world" def fake_post(url, data=None, headers=None, timeout=None, stream=False): - assert stream is True - assert headers["X-OHTTP-Stream"] == "true" - decap = srv.decapsulate_request(hpke_priv, data) - encr = srv.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) + assert stream is True and headers["X-OHTTP-Stream"] == "true" + decap = recipient.decapsulate_request(hpke_priv, data) + encr = recipient.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) wire = encr.header() wire += encr.encrypt_chunk(b'data: {"choices":[{"delta":{"content":"Hello "},"index":0}]}\n\n', is_final=False) wire += encr.encrypt_chunk(b'data: {"choices":[{"delta":{"content":"world"},"index":0}]}\n\n', is_final=False) @@ -142,15 +118,70 @@ def fake_post(url, data=None, headers=None, timeout=None, stream=False): "tee_id": endpoint.tee_id, } wire += encr.encrypt_chunk(f"data: {json.dumps(final)}\n\n".encode(), is_final=True) - # Deliver as a couple of network chunks to exercise buffering. mid = len(wire) // 2 return _FakeResp(chunks=[wire[:mid], wire[mid:]]) - class _Sess: - post = staticmethod(fake_post) - - client = OhttpRelayClient("https://relay/api/v1/chat/ohttp", endpoint, session=_Sess()) + client = OhttpRelayClient("https://relay/api/v1/chat/ohttp", endpoint, session=_session_with(fake_post)) result = client.stream_chat_completion(body) assert result.content == full assert result.proof.timestamp == 1_700_000_001 assert result.stream_frames is not None and len(result.stream_frames) == 3 + + +def test_streaming_tool_calls_verify(recipient): + """Tool calls streamed as deltas must reconstruct the gateway's signed output.""" + endpoint, hpke_priv, rsa_priv = _make_endpoint(recipient) + body = {"model": "gpt-4.1", "messages": [{"role": "user", "content": "weather?"}]} + _wire, canonical = build_inner_request(body) + # What the gateway signs for a tool-call response: + tool_calls = [{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": '{"city":"NYC"}'}}] + signed_output = json.dumps(tool_calls, sort_keys=True) + + def fake_post(url, data=None, headers=None, timeout=None, stream=False): + decap = recipient.decapsulate_request(hpke_priv, data) + encr = recipient.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) + wire = encr.header() + + # Stream the tool call as fragments (id+name first, then argument chunks). + def _tc(frag): + return {"choices": [{"delta": {"tool_calls": [frag]}, "index": 0}]} + + f1 = _tc({"index": 0, "id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": ""}}) + f2 = _tc({"index": 0, "function": {"arguments": '{"city":'}}) + f3 = _tc({"index": 0, "function": {"arguments": '"NYC"}'}}) + wire += encr.encrypt_chunk(f"data: {json.dumps(f1)}\n\n".encode(), is_final=False) + wire += encr.encrypt_chunk(f"data: {json.dumps(f2)}\n\n".encode(), is_final=False) + wire += encr.encrypt_chunk(f"data: {json.dumps(f3)}\n\n".encode(), is_final=False) + final = { + "choices": [{"delta": {}, "index": 0, "finish_reason": "tool_calls"}], + **_sign_fields(rsa_priv, canonical, signed_output, 1_700_000_002), + "tee_id": endpoint.tee_id, + } + wire += encr.encrypt_chunk(f"data: {json.dumps(final)}\n\n".encode(), is_final=True) + return _FakeResp(chunks=[wire]) + + client = OhttpRelayClient("https://relay/api/v1/chat/ohttp", endpoint, session=_session_with(fake_post)) + result = client.stream_chat_completion(body) + assert result.content == signed_output + assert result.proof.timestamp == 1_700_000_002 + + +def test_malformed_stream_raises_verification_error(recipient): + from opengradient.client import VerificationError + + endpoint, hpke_priv, _rsa = _make_endpoint(recipient) + body = {"model": "gpt-4.1", "messages": [{"role": "user", "content": "hi"}]} + + def fake_post(url, data=None, headers=None, timeout=None, stream=False): + decap = recipient.decapsulate_request(hpke_priv, data) + encr = recipient.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) + # No final marker -> truncated stream. + wire = encr.header() + encr.encrypt_chunk(b"data: x\n\n", is_final=False) + return _FakeResp(chunks=[wire]) + + client = OhttpRelayClient("https://relay/api/v1/chat/ohttp", endpoint, session=_session_with(fake_post)) + try: + client.stream_chat_completion(body) + assert False, "expected VerificationError" + except VerificationError: + pass diff --git a/tests/tee_ohttp_test.py b/tests/tee_ohttp_test.py index 87ec923..ac097a0 100644 --- a/tests/tee_ohttp_test.py +++ b/tests/tee_ohttp_test.py @@ -1,64 +1,38 @@ -"""Wire-compatibility round-trips for the client OHTTP encapsulation. +"""OHTTP client wire-compatibility round-trips. -When a tee-gateway checkout is available (``OG_TEE_GATEWAY`` env var, or a sibling -``../tee-gateway``), these round-trip our client crypto against the *actual* -server recipient code, guaranteeing the two stay byte-compatible. The gateway's -``tee_gateway/ohttp.py`` only needs ``cryptography`` + ``pyhpke``, so importing it -standalone is cheap. The tests skip cleanly when no checkout is found. +These run in CI against a self-contained in-repo recipient (the ``recipient`` +fixture), so the encapsulation/decryption code is always exercised. When a real +tee-gateway checkout is present, ``test_cross_check_against_real_gateway`` also +round-trips against the actual server code to catch any drift. """ from __future__ import annotations -import os -import sys -from pathlib import Path - import pytest from opengradient.client import tee_ohttp as cli -def _load_server_ohttp(): - override = os.getenv("OG_TEE_GATEWAY") - candidates = [Path(override)] if override else [] - candidates.append(Path(__file__).resolve().parents[2] / "tee-gateway") - for root in candidates: - if (root / "tee_gateway" / "ohttp.py").exists(): - sys.path.insert(0, str(root)) - import tee_gateway.ohttp as srv - - return srv - return None - - -@pytest.fixture(scope="module") -def server_ohttp(): - srv = _load_server_ohttp() - if srv is None: - pytest.skip("tee-gateway checkout not found (set OG_TEE_GATEWAY)") - return srv - - -def test_request_and_single_shot_response(server_ohttp): - priv, pub_raw = server_ohttp.generate_keypair() +def test_request_and_single_shot_response(recipient): + priv, pub_raw = recipient.generate_keypair() plaintext = b'{"model":"gpt-4.1","messages":[{"role":"user","content":"hi"}]}' enc_req = cli.encapsulate_request(pub_raw, plaintext) - decap = server_ohttp.decapsulate_request(priv, enc_req.wire) + decap = recipient.decapsulate_request(priv, enc_req.wire) assert decap.plaintext == plaintext assert decap.enc == enc_req.enc resp_pt = b'{"choices":[{"message":{"content":"hello"}}]}' - sealed = server_ohttp.encapsulate_response(decap.response_key, decap.enc, resp_pt) + sealed = recipient.encapsulate_response(decap.response_key, decap.enc, resp_pt) assert cli.decrypt_response(enc_req.response_secret, enc_req.enc, sealed) == resp_pt -def test_chunked_response_whole_and_incremental(server_ohttp): - priv, pub_raw = server_ohttp.generate_keypair() +def test_chunked_response_whole_and_incremental(recipient): + priv, pub_raw = recipient.generate_keypair() enc_req = cli.encapsulate_request(pub_raw, b"{}") - decap = server_ohttp.decapsulate_request(priv, enc_req.wire) + decap = recipient.decapsulate_request(priv, enc_req.wire) - encr = server_ohttp.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) + encr = recipient.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) wire = encr.header() frames = [b"data: a\n\n", b"data: b\n\n", b"data: [DONE]\n\n"] wire += encr.encrypt_chunk(frames[0], is_final=False) @@ -75,11 +49,11 @@ def test_chunked_response_whole_and_incremental(server_ohttp): assert out == frames -def test_truncated_chunked_stream_is_rejected(server_ohttp): - priv, pub_raw = server_ohttp.generate_keypair() +def test_truncated_chunked_stream_is_rejected(recipient): + priv, pub_raw = recipient.generate_keypair() enc_req = cli.encapsulate_request(pub_raw, b"{}") - decap = server_ohttp.decapsulate_request(priv, enc_req.wire) - encr = server_ohttp.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) + decap = recipient.decapsulate_request(priv, enc_req.wire) + encr = recipient.ChunkedResponseEncrypter(decap.response_key_chunked, decap.enc) wire = encr.header() + encr.encrypt_chunk(b"data: a\n\n", is_final=False) dec = cli.ChunkedResponseDecrypter(enc_req.chunked_response_secret, enc_req.enc) with pytest.raises(ValueError): @@ -89,3 +63,28 @@ def test_truncated_chunked_stream_is_rejected(server_ohttp): def test_rejects_wrong_size_public_key(): with pytest.raises(ValueError): cli.encapsulate_request(b"too short", b"{}") + + +def test_rejects_unsupported_suite(): + with pytest.raises(ValueError, match="unsupported HPKE suite"): + cli.encapsulate_request(b"\x00" * 32, b"{}", aead_id=0x0001) + + +def test_custom_key_id_round_trips(recipient): + # A TEE that rotated to key_id=0x07 must still decapsulate (the id is carried + # in the header and bound into the HPKE info string on both sides). + priv, pub_raw = recipient.generate_keypair() + enc_req = cli.encapsulate_request(pub_raw, b"{}", key_id=0x07) + assert enc_req.wire[0] == 0x07 + decap = recipient.decapsulate_request(priv, enc_req.wire) + assert decap.plaintext == b"{}" + + +def test_cross_check_against_real_gateway(real_tee_gateway): + """When a tee-gateway checkout is present, confirm we're byte-compatible with it.""" + priv, pub_raw = real_tee_gateway.generate_keypair() + enc_req = cli.encapsulate_request(pub_raw, b'{"ping":1}') + decap = real_tee_gateway.decapsulate_request(priv, enc_req.wire) + assert decap.plaintext == b'{"ping":1}' + sealed = real_tee_gateway.encapsulate_response(decap.response_key, decap.enc, b'{"pong":1}') + assert cli.decrypt_response(enc_req.response_secret, enc_req.enc, sealed) == b'{"pong":1}' diff --git a/tests/tee_verify_test.py b/tests/tee_verify_test.py index d467f06..611a51d 100644 --- a/tests/tee_verify_test.py +++ b/tests/tee_verify_test.py @@ -156,3 +156,28 @@ def test_build_inner_request_strips_attachment_bytes_from_hash_only(): assert canonical["messages"][0]["content"][1] == {"type": "image_url"} assert canonical["temperature"] == 0.7 assert canonical["max_tokens"] == 256 + + +def test_non_list_tools_rejected(): + with pytest.raises(verify.UnsupportedRequestError, match="tools"): + build_inner_request( + {"model": "gpt-4.1", "messages": [{"role": "user", "content": "x"}], "tools": {"a": 1}} + ) + + +def test_non_dict_message_rejected(): + with pytest.raises(verify.UnsupportedRequestError, match="message must be"): + build_inner_request({"model": "gpt-4.1", "messages": ["not a dict"]}) + + +def test_tool_choice_preserved_on_wire_but_not_hashed(): + body = { + "model": "gpt-4.1", + "messages": [{"role": "user", "content": "x"}], + "tool_choice": "auto", + } + wire, canonical = build_inner_request(body) + # Forwarded to the gateway so caller intent isn't silently dropped... + assert wire["tool_choice"] == "auto" + # ...but excluded from the signed request hash (the gateway doesn't commit to it). + assert "tool_choice" not in canonical