Mirror of https://github.com/JJTech0130/pypush.git (synced 2025-01-22 11:18:29 +00:00)
Rewrite: APNs: Scoped App Tokens (#101)

+ Adds a lot of new API surface around filters
+ Adds CI type checking and linting

parent b1c30a98ff
commit ce88a8ea5c
.github/workflows/pyright.yml (vendored, new file, +18)
@@ -0,0 +1,18 @@
+name: Pyright
+on: [push, pull_request]
+jobs:
+  pyright:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v4
+        with:
+          cache: 'pip'
+
+      - run: |
+          python -m venv .venv
+          source .venv/bin/activate
+          pip install -e '.[test,cli]'
+
+      - run: echo "$PWD/.venv/bin" >> $GITHUB_PATH
+      - uses: jakebailey/pyright-action@v2
.github/workflows/ruff.yml (vendored, new file, +8)
@@ -0,0 +1,8 @@
+name: Ruff
+on: [push, pull_request]
+jobs:
+  ruff:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: chartboost/ruff-action@v1
@@ -34,4 +34,8 @@ version_file = "pypush/_version.py"
 [tool.pytest.ini_options]
 minversion = "6.0"
 addopts = ["-ra", "-q"]
-testpaths = ["tests"]
+testpaths = ["tests"]
+
+[tool.ruff.lint]
+select = ["E", "F", "B", "SIM", "I"]
+ignore = ["E501", "B010"]
@@ -1,5 +1,5 @@
-__all__ = ["protocol", "create_apns_connection", "activate"]
+__all__ = ["protocol", "create_apns_connection", "activate", "filters"]
 
-from . import protocol
-from .lifecycle import create_apns_connection
+from . import filters, protocol
 from .albert import activate
+from .lifecycle import create_apns_connection
@@ -3,7 +3,7 @@ from __future__ import annotations
 import logging
 from dataclasses import MISSING, field
 from dataclasses import fields as dataclass_fields
-from typing import Any, TypeVar, get_origin, get_args, Union
+from typing import Any, TypeVar, Union, get_args, get_origin
 
 from pypush.apns.transport import Packet
 
@@ -67,14 +67,14 @@ def command(cls: T) -> T:
             )
 
         # Check for extra fields
-        for field in packet.fields:
-            if field.id not in [
+        for current_field in packet.fields:
+            if current_field.id not in [
                 f.metadata["packet_id"]
                 for f in dataclass_fields(cls)
                 if f.metadata is not None and "packet_id" in f.metadata
             ]:
                 logging.warning(
-                    f"Unexpected field with packet ID {field.id} in packet {packet}"
+                    f"Unexpected field with packet ID {current_field.id} in packet {packet}"
                 )
         return cls(**field_values)
 
@@ -122,15 +122,15 @@ def fid(
     :param byte_len: The length of the field in bytes (for int fields)
     :param default: The default value of the field
     """
-    if not default == MISSING and not default_factory == MISSING:
+    if default != MISSING and default_factory != MISSING:
         raise ValueError("Cannot specify both default and default_factory")
-    if not default == MISSING:
+    if default != MISSING:
         return field(
             metadata={"packet_id": packet_id, "packet_bytes": byte_len},
             default=default,
             repr=repr,
         )
-    if not default_factory == MISSING:
+    if default_factory != MISSING:
        return field(
             metadata={"packet_id": packet_id, "packet_bytes": byte_len},
             default_factory=default_factory,
@@ -3,25 +3,40 @@ from contextlib import asynccontextmanager
 from typing import Generic, TypeVar
 
 import anyio
-from anyio.abc import ObjectSendStream
+from anyio.abc import ObjectReceiveStream, ObjectSendStream
+
+from . import filters
 
 T = TypeVar("T")
 
 
 class BroadcastStream(Generic[T]):
-    def __init__(self):
+    def __init__(self, backlog: int = 50):
         self.streams: list[ObjectSendStream[T]] = []
+        self.backlog: list[T] = []
+        self._backlog_size = backlog
 
     async def broadcast(self, packet):
         logging.debug(f"Broadcasting {packet} to {len(self.streams)} streams")
         for stream in self.streams:
             try:
                 await stream.send(packet)
             except anyio.BrokenResourceError:
-                self.streams.remove(stream)
+                logging.error("Broken resource error")
+                # self.streams.remove(stream)
+        # If we have a backlog, add the packet to it
+        if len(self.backlog) >= self._backlog_size:
+            self.backlog.pop(0)
+        self.backlog.append(packet)
 
     @asynccontextmanager
-    async def open_stream(self):
-        send, recv = anyio.create_memory_object_stream[T]()
+    async def open_stream(self, backlog: bool = True):
+        # 1000 seems like a reasonable number, if more than 1000 messages come in before someone deals with them it will
+        # start stalling the APNs connection itself
+        send, recv = anyio.create_memory_object_stream[T](max_buffer_size=1000)
+        if backlog:
+            for packet in self.backlog:
+                await send.send(packet)
         self.streams.append(send)
         async with recv:
             yield recv
@@ -29,6 +44,31 @@ class BroadcastStream(Generic[T]):
         await send.aclose()
 
 
+W = TypeVar("W")
+F = TypeVar("F")
+
+
+class FilteredStream(ObjectReceiveStream[F]):
+    """
+    A stream that filters out unwanted items
+
+    filter should return None if the item should be filtered out, otherwise it should return the item or a modified version of it
+    """
+
+    def __init__(self, source: ObjectReceiveStream[W], filter: filters.Filter[W, F]):
+        self.source = source
+        self.filter = filter
+
+    async def receive(self) -> F:
+        async for item in self.source:
+            if (filtered := self.filter(item)) is not None:
+                return filtered
+        raise anyio.EndOfStream
+
+    async def aclose(self):
+        await self.source.aclose()
+
+
 def exponential_backoff(f):
     async def wrapper(*args, **kwargs):
         backoff = 1
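For context, a minimal sketch of how the new BroadcastStream backlog and FilteredStream are meant to interact (illustrative only, not part of the diff; the integer payloads are placeholders for protocol commands):

# Illustrative sketch only, assuming the pypush.apns._util module as changed above.
import anyio

from pypush.apns import _util


async def demo():
    broadcast = _util.BroadcastStream[int](backlog=10)

    # Nothing is listening yet, so these only land in the backlog.
    for i in range(5):
        await broadcast.broadcast(i)

    # A late subscriber gets the backlog replayed, filtered down to even items.
    async with broadcast.open_stream(backlog=True) as raw:
        evens = _util.FilteredStream(raw, lambda i: i if i % 2 == 0 else None)
        print(await evens.receive())  # -> 0
        print(await evens.receive())  # -> 2


anyio.run(demo)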
@@ -4,7 +4,7 @@ import plistlib
 import re
 import uuid
 from base64 import b64decode
-from typing import Tuple, Optional
+from typing import Optional, Tuple
 
 import httpx
 from cryptography import x509
@@ -96,10 +96,10 @@ async def activate(
 
     try:
         protocol = re.search("<Protocol>(.*)</Protocol>", resp.text).group(1)  # type: ignore
-    except AttributeError:
+    except AttributeError as e:
         # Search for error text between <b> and </b>
         error = re.search("<b>(.*)</b>", resp.text).group(1)  # type: ignore
-        raise Exception(f"Failed to get certificate from Albert: {error}")
+        raise Exception(f"Failed to get certificate from Albert: {error}") from e
 
     protocol = plistlib.loads(protocol.encode("utf-8"))
 
pypush/apns/filters.py (new file, +44)
@@ -0,0 +1,44 @@
+import logging
+from typing import Callable, Optional, Type, TypeVar
+
+from pypush.apns import protocol
+
+T1 = TypeVar("T1")
+T2 = TypeVar("T2")
+Filter = Callable[[T1], Optional[T2]]
+
+# Chain with proper types so that subsequent filters only need to take output type of previous filter
+T_IN = TypeVar("T_IN", bound=protocol.Command)
+T_MIDDLE = TypeVar("T_MIDDLE", bound=protocol.Command)
+T_OUT = TypeVar("T_OUT", bound=protocol.Command)
+
+
+def chain(first: Filter[T_IN, T_MIDDLE], second: Filter[T_MIDDLE, T_OUT]):
+    def filter(command: T_IN) -> Optional[T_OUT]:
+        logging.debug(f"Filtering {command} with {first} and {second}")
+        filtered = first(command)
+        if filtered is None:
+            return None
+        return second(filtered)
+
+    return filter
+
+
+T = TypeVar("T", bound=protocol.Command)
+
+
+def cmd(type: Type[T]) -> Filter[protocol.Command, T]:
+    def filter(command: protocol.Command) -> Optional[T]:
+        if isinstance(command, type):
+            return command
+        return None
+
+    return filter
+
+
+def ALL(c):
+    return c
+
+
+def NONE(_):
+    return None
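A quick sketch of how these combinators compose (illustrative only, not part of the diff): a Filter is just a callable that returns the item (possibly transformed) or None to drop it, `cmd` narrows to one command type, and `chain` composes two filters.

# Illustrative sketch only; uses the helpers defined above.
from pypush.apns import filters, protocol

# Keep only ConnectAck commands whose status is 0.
ok_connect_ack = filters.chain(
    filters.cmd(protocol.ConnectAck),
    lambda c: c if c.status == 0 else None,
)

# Anything that is not a ConnectAck is dropped by the first stage.
assert ok_connect_ack(protocol.KeepAliveCommand()) is None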
@@ -6,6 +6,7 @@ import random
 import time
 import typing
 from contextlib import asynccontextmanager
+from hashlib import sha1
 
 import anyio
 from anyio.abc import TaskGroup
@@ -13,7 +14,7 @@ from cryptography import x509
 from cryptography.hazmat.primitives import hashes, serialization
 from cryptography.hazmat.primitives.asymmetric import padding, rsa
 
-from . import protocol, transport, _util
+from . import _util, filters, protocol, transport
 
 
 @asynccontextmanager
@@ -21,13 +22,18 @@ async def create_apns_connection(
     certificate: x509.Certificate,
     private_key: rsa.RSAPrivateKey,
     token: typing.Optional[bytes] = None,
+    sandbox: bool = False,
     courier: typing.Optional[str] = None,
 ):
     async with anyio.create_task_group() as tg:
-        conn = Connection(tg, certificate, private_key, token, courier)
+        conn = Connection(
+            tg, certificate, private_key, token, sandbox, courier
+        )  # Await connected for first time here, so that base token is set
         yield conn
         tg.cancel_scope.cancel()  # Cancel the task group when the context manager exits
-        await conn.aclose()  # Make sure to close the connection after the task group is cancelled
+        await (
+            conn.aclose()
+        )  # Make sure to close the connection after the task group is cancelled
 
 
 class Connection:
@@ -37,26 +43,44 @@ class Connection:
         certificate: x509.Certificate,
         private_key: rsa.RSAPrivateKey,
         token: typing.Optional[bytes] = None,
+        sandbox: bool = False,
         courier: typing.Optional[str] = None,
     ):
 
         self.certificate = certificate
         self.private_key = private_key
-        self.base_token = token
+        self._base_token = token
+
+        self._filters: dict[str, int] = {}  # topic -> use count
+
+        self._connected = anyio.Event()  # Only use for base_token property
 
         self._conn = None
         self._tg = task_group
         self._broadcast = _util.BroadcastStream[protocol.Command]()
-        self._reconnect_lock = anyio.Lock()
+        self._send_lock = anyio.Lock()
 
+        self.sandbox = sandbox
         if courier is None:
             # Pick a random courier server from 1 to 50
-            courier = f"{random.randint(1, 50)}-courier.push.apple.com"
+            courier = (
+                f"{random.randint(1, 50)}-courier.push.apple.com"
+                if not sandbox
+                else f"{random.randint(1, 10)}-courier.sandbox.push.apple.com"
+            )
         logging.debug(f"Using courier: {courier}")
         self.courier = courier
 
         self._tg.start_soon(self.reconnect)
         self._tg.start_soon(self._ping_task)
 
+    @property
+    async def base_token(self) -> bytes:
+        if self._base_token is None:
+            await self._connected.wait()
+        assert self._base_token is not None
+        return self._base_token
+
     async def _receive_task(self):
         assert self._conn is not None
         async for command in self._conn:
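Note that `base_token` is now an awaitable property backed by the `_connected` event: reading it blocks until the first ConnectAck has populated the token. Illustrative usage (not part of the diff):

# Inside an `async with apns.create_apns_connection(...) as connection:` block:
token = await connection.base_token  # waits for the first successful connect if needed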
@@ -68,8 +92,10 @@ class Connection:
         while True:
             await anyio.sleep(30)
             logging.debug("Sending keepalive")
-            await self.send(protocol.KeepAliveCommand())
-            await self.receive(protocol.KeepAliveAck)
+            await self._send(protocol.KeepAliveCommand())
+            await self._receive(
+                filters.cmd(protocol.KeepAliveAck), backlog=False
+            )  # Explicitly disable the backlog since we don't want to receive old acks
 
     @_util.exponential_backoff
     async def reconnect(self):
@@ -77,8 +103,11 @@
         if self._conn is not None:
             logging.warning("Closing existing connection")
             await self._conn.aclose()
-        self._conn = protocol.CommandStream(
-            await transport.create_courier_connection(courier=self.courier)
+
+        self._broadcast.backlog = []  # Clear the backlog
+
+        conn = protocol.CommandStream(
+            await transport.create_courier_connection(self.sandbox, self.courier)
         )
         cert = self.certificate.public_bytes(serialization.Encoding.DER)
         nonce = (
@@ -89,53 +118,150 @@
         signature = b"\x01\x01" + self.private_key.sign(
             nonce, padding.PKCS1v15(), hashes.SHA1()
         )
-        await self._conn.send(
+        await conn.send(
             protocol.ConnectCommand(
-                push_token=self.base_token,
+                push_token=self._base_token,
                 state=1,
-                flags=69,
+                flags=65,  # 69
                 certificate=cert,
                 nonce=nonce,
                 signature=signature,
             )
         )
+
+        # Don't set self._conn until we've sent the connect command
+        self._conn = conn
+
         self._tg.start_soon(self._receive_task)
-        ack = await self.receive(protocol.ConnectAck)
+        ack = await self._receive(
+            filters.chain(
+                filters.cmd(protocol.ConnectAck),
+                lambda c: (
+                    c
+                    if (
+                        c.token == self._base_token
+                        if self._base_token is not None
+                        else True
+                    )
+                    else None
+                ),
+            )
+        )
         logging.debug(f"Connected with ack: {ack}")
         assert ack.status == 0
-        if self.base_token is None:
-            self.base_token = ack.token
+        if self._base_token is None:
+            self._base_token = ack.token
         else:
-            assert ack.token == self.base_token
+            assert ack.token == self._base_token
+        if not self._connected.is_set():
+            self._connected.set()
 
+        await self._update_filter()
 
     async def aclose(self):
         if self._conn is not None:
             await self._conn.aclose()
         # Note: Will be reopened if task group is still running and ping task is still running
 
-    T = typing.TypeVar("T", bound=protocol.Command)
+    T = typing.TypeVar("T")
 
-    async def receive_stream(
-        self, filter: typing.Type[T], max: int = -1
-    ) -> typing.AsyncIterator[T]:
-        async with self._broadcast.open_stream() as stream:
+    @asynccontextmanager
+    async def _receive_stream(
+        self,
+        filter: filters.Filter[protocol.Command, T] = lambda c: c,
+        backlog: bool = True,
+    ):
+        async with self._broadcast.open_stream(backlog) as stream:
+            yield _util.FilteredStream(stream, filter)
+
+    async def _receive(
+        self, filter: filters.Filter[protocol.Command, T], backlog: bool = True
+    ):
+        async with self._receive_stream(filter, backlog) as stream:
             async for command in stream:
-                if isinstance(command, filter):
-                    yield command
-                    max -= 1
-                    if max == 0:
-                        break
+                return command
+            raise ValueError("Did not receive expected command")
 
-    async def receive(self, filter: typing.Type[T]) -> T:
-        async for command in self.receive_stream(filter, 1):
-            return command
-        raise ValueError("No matching command received")
-
-    async def send(self, command: protocol.Command):
+    async def _send(self, command: protocol.Command):
         try:
-            assert self._conn is not None
-            await self._conn.send(command)
-        except Exception as e:
-            logging.warning(f"Error sending command, reconnecting")
+            async with self._send_lock:
+                assert self._conn is not None
+                await self._conn.send(command)
+        except Exception:
+            logging.warning("Error sending command, reconnecting")
             await self.reconnect()
-            await self.send(command)
+            await self._send(command)
+
+    async def _update_filter(self):
+        await self._send(
+            protocol.FilterCommand(
+                token=await self.base_token,
+                enabled_topic_hashes=[
+                    sha1(topic.encode()).digest() for topic in self._filters
+                ],
+            )
+        )
+
+    @asynccontextmanager
+    async def _filter(self, topics: list[str]):
+        for topic in topics:
+            self._filters[topic] = self._filters.get(topic, 0) + 1
+        await self._update_filter()
+        yield
+        for topic in topics:
+            self._filters[topic] -= 1
+            if self._filters[topic] == 0:
+                del self._filters[topic]
+        await self._update_filter()
+
+    async def mint_scoped_token(self, topic: str) -> bytes:
+        topic_hash = sha1(topic.encode()).digest()
+        await self._send(
+            protocol.ScopedTokenCommand(token=await self.base_token, topic=topic_hash)
+        )
+        ack = await self._receive(filters.cmd(protocol.ScopedTokenAck))
+        assert ack.status == 0
+        return ack.scoped_token
+
+    @asynccontextmanager
+    async def notification_stream(
+        self,
+        topic: str,
+        token: typing.Optional[bytes] = None,
+        filter: filters.Filter[
+            protocol.SendMessageCommand, protocol.SendMessageCommand
+        ] = filters.ALL,
+    ):
+        if token is None:
+            token = await self.base_token
+        async with self._filter([topic]), self._receive_stream(
+            filters.chain(
+                filters.chain(
+                    filters.chain(
+                        filters.cmd(protocol.SendMessageCommand),
+                        lambda c: c if c.token == token else None,
+                    ),
+                    lambda c: (c if c.topic == topic else None),
+                ),
+                filter,
+            )
+        ) as stream:
+            yield stream
+
+    async def ack(self, command: protocol.SendMessageCommand, status: int = 0):
+        await self._send(
+            protocol.SendMessageAck(status=status, token=command.token, id=command.id)
+        )
+
+    async def expect_notification(
+        self,
+        topic: str,
+        token: typing.Optional[bytes] = None,
+        filter: filters.Filter[
+            protocol.SendMessageCommand, protocol.SendMessageCommand
+        ] = filters.ALL,
+    ) -> protocol.SendMessageCommand:
+        async with self.notification_stream(topic, token, filter) as stream:
+            command = await stream.receive()
+            await self.ack(command)
+            return command
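Putting the new surface together, a minimal sketch of the intended flow (illustrative only, not part of the diff; it mirrors the test and CLI code later in this commit, and the topic name is a placeholder): activate a device, mint a per-topic scoped token, then stream and ack notifications for that topic.

# Illustrative sketch only; mirrors how the new API is exercised elsewhere in this commit.
import anyio

from pypush import apns


async def listen(topic: str):
    async with apns.create_apns_connection(
        *await apns.activate(), sandbox=True
    ) as connection:
        # The scoped token is what you hand to the provider-side APNs HTTP/2 API.
        scoped = await connection.mint_scoped_token(topic)
        print("scoped token:", scoped.hex())

        # notification_stream enables the topic filter for the duration of the block.
        async with connection.notification_stream(topic, scoped) as stream:
            async for message in stream:
                await connection.ack(message)
                print(message.payload)


anyio.run(listen, "dev.example.topic")  # placeholder topic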
@@ -2,7 +2,7 @@ from dataclasses import dataclass
 from hashlib import sha1
 from typing import Optional, Union
 
-from anyio.abc import ByteStream, ObjectStream
+from anyio.abc import ObjectStream
 
 from pypush.apns._protocol import command, fid
 from pypush.apns.transport import Packet
@@ -87,12 +87,7 @@ class FilterCommand(Command):
 
     def _lookup_hashes(self, hashes: Optional[list[bytes]]):
         return (
-            [
-                KNOWN_TOPICS_LOOKUP[hash] if hash in KNOWN_TOPICS_LOOKUP else hash
-                for hash in hashes
-            ]
-            if hashes
-            else []
+            [KNOWN_TOPICS_LOOKUP.get(hash, hash) for hash in hashes] if hashes else []
         )
 
     @property
@@ -140,6 +135,7 @@ class KeepAliveAck(Command):
     PacketType = Packet.Type.KeepAliveAck
     unknown: Optional[int] = fid(1)
 
+
 @command
 @dataclass
 class SetStateCommand(Command):
@@ -182,7 +178,7 @@ class SendMessageCommand(Command):
         ) and not (self._token_topic_1 is not None and self._token_topic_2 is not None):
             raise ValueError("topic, token, and outgoing must be set.")
 
-        if self.outgoing == True:
+        if self.outgoing is True:
             assert self.topic and self.token
             self._token_topic_1 = (
                 sha1(self.topic.encode()).digest()
@@ -190,7 +186,7 @@ class SendMessageCommand(Command):
                 else self.topic
             )
             self._token_topic_2 = self.token
-        elif self.outgoing == False:
+        elif self.outgoing is False:
             assert self.topic and self.token
             self._token_topic_1 = self.token
             self._token_topic_2 = (
@@ -201,18 +197,14 @@
         else:
             assert self._token_topic_1 and self._token_topic_2
             if len(self._token_topic_1) == 20:  # SHA1 hash, topic
-                self.topic = (
-                    KNOWN_TOPICS_LOOKUP[self._token_topic_1]
-                    if self._token_topic_1 in KNOWN_TOPICS_LOOKUP
-                    else self._token_topic_1
+                self.topic = KNOWN_TOPICS_LOOKUP.get(
+                    self._token_topic_1, self._token_topic_1
                 )
                 self.token = self._token_topic_2
                 self.outgoing = True
             else:
-                self.topic = (
-                    KNOWN_TOPICS_LOOKUP[self._token_topic_2]
-                    if self._token_topic_2 in KNOWN_TOPICS_LOOKUP
-                    else self._token_topic_2
+                self.topic = KNOWN_TOPICS_LOOKUP.get(
+                    self._token_topic_2, self._token_topic_2
                 )
                 self.token = self._token_topic_1
                 self.outgoing = False
@@ -229,6 +221,27 @@ class SendMessageAck(Command):
     unknown6: Optional[bytes] = fid(6, default=None)
 
 
+@command
+@dataclass
+class ScopedTokenCommand(Command):
+    PacketType = Packet.Type.ScopedToken
+
+    token: bytes = fid(1)
+    topic: bytes = fid(2)
+    app_id: Optional[bytes] = fid(3, default=None)
+
+
+@command
+@dataclass
+class ScopedTokenAck(Command):
+    PacketType = Packet.Type.ScopedTokenAck
+
+    status: int = fid(1)
+    scoped_token: bytes = fid(2)
+    topic: bytes = fid(3)
+    app_id: Optional[bytes] = fid(4, default=None)
+
+
 @dataclass
 class UnknownCommand(Command):
     id: Packet.Type
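For reference, the wire-level exchange that Connection.mint_scoped_token drives with these two commands looks roughly like this (illustrative sketch only, not part of the diff; the token value and topic name are placeholders):

# Illustrative sketch only; field IDs follow the fid() declarations above.
from hashlib import sha1

from pypush.apns.protocol import ScopedTokenCommand

base_token = b"\x00" * 32  # placeholder for the connection's base push token
request = ScopedTokenCommand(
    token=base_token,
    topic=sha1(b"dev.example.topic").digest(),  # topics travel as SHA-1 hashes
)
# The courier replies with a ScopedTokenAck(status=0, scoped_token=..., topic=..., app_id=None),
# which mint_scoped_token() awaits via filters.cmd(protocol.ScopedTokenAck).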
@@ -240,7 +253,7 @@ class UnknownCommand(Command):
 
     def to_packet(self) -> Packet:
         return Packet(id=self.id, fields=self.fields)
 
     def __repr__(self):
         if self.id.value in [29, 30, 32]:
             return f"UnknownCommand(id={self.id}, fields=[SUPPRESSED])"
@@ -259,6 +272,8 @@ def command_from_packet(packet: Packet) -> Command:
         Packet.Type.SetState: SetStateCommand,
         Packet.Type.SendMessage: SendMessageCommand,
         Packet.Type.SendMessageAck: SendMessageAck,
+        Packet.Type.ScopedToken: ScopedTokenCommand,
+        Packet.Type.ScopedTokenAck: ScopedTokenAck,
         # Add other mappings here...
     }
     command_class = command_classes.get(packet.id, None)
@@ -30,6 +30,8 @@ class Packet:
         KeepAlive = 12
         KeepAliveAck = 13
         NoStorage = 14
+        ScopedToken = 17
+        ScopedTokenAck = 18
         SetState = 20
         UNKNOWN = "Unknown"
 
@@ -38,20 +40,19 @@
             obj = object.__new__(cls)
             obj._value_ = value
             return obj
 
         @classmethod
         def _missing_(cls, value):
             # Handle unknown values
             instance = cls.UNKNOWN
             instance._value_ = value  # Assign the unknown value
             return instance
 
         def __str__(self):
             if self is Packet.Type.UNKNOWN:
                 return f"Unknown({self._value_})"
             return self.name
 
     id: Type
     fields: list[Field]
 
@@ -60,18 +61,25 @@
 
 
 async def create_courier_connection(
+    sandbox: bool = False,
     courier: str = "1-courier.push.apple.com",
 ) -> PacketStream:
     context = ssl.create_default_context()
     context.set_alpn_protocols(ALPN)
 
+    sni = "courier.sandbox.push.apple.com" if sandbox else "courier.push.apple.com"
+
     # TODO: Verify courier certificate
     context.check_hostname = False
     context.verify_mode = ssl.CERT_NONE
 
     return PacketStream(
         await anyio.connect_tcp(
-            courier, COURIER_PORT, ssl_context=context, tls_standard_compatible=False
+            courier,
+            COURIER_PORT,
+            ssl_context=context,
+            tls_standard_compatible=False,
+            tls_hostname=sni,
         )
     )
 
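The sandbox flag only changes the TLS SNI here (the courier host itself is chosen by the caller, as in lifecycle.py above). A sketch of direct low-level use (illustrative only, not part of the diff):

# Illustrative sketch only: open a raw sandbox courier connection, then wrap it for Command I/O.
import anyio

from pypush.apns import protocol, transport


async def main():
    stream = await transport.create_courier_connection(
        sandbox=True, courier="1-courier.sandbox.push.apple.com"
    )
    conn = protocol.CommandStream(stream)  # speak Command objects instead of raw packets
    await conn.aclose()


anyio.run(main)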
@@ -1,12 +1,17 @@
+import contextlib
 import logging
+from asyncio import CancelledError
 
+import anyio
 import typer
 from rich.logging import RichHandler
 from typing_extensions import Annotated
 
+from pypush import apns
+
 from . import proxy as _proxy
 
-logging.basicConfig(level=logging.DEBUG, handlers=[RichHandler()], format="%(message)s")
+logging.basicConfig(level=logging.INFO, handlers=[RichHandler()], format="%(message)s")
 
 app = typer.Typer()
 
@@ -22,12 +27,12 @@ def proxy(
 
     Attach requires SIP to be disabled and to be running as root
     """
-
-    _proxy.main(attach)
+    with contextlib.suppress(CancelledError):
+        _proxy.main(attach)
 
 
 @app.command()
-def client(
+def notifications(
     topic: Annotated[str, typer.Argument(help="app topic to listen on")],
     sandbox: Annotated[
         bool, typer.Option("--sandbox/--production", help="APNs courier to use")
@@ -36,8 +41,29 @@ def client(
     """
     Connect to the APNs courier and listen for app notifications on the given topic
     """
-    typer.echo("Running APNs client")
-    raise NotImplementedError("Not implemented yet")
+    logging.getLogger("httpx").setLevel(logging.WARNING)
+    with contextlib.suppress(CancelledError):
+        anyio.run(notifications_async, topic, sandbox)
+
+
+async def notifications_async(topic: str, sandbox: bool):
+    async with apns.create_apns_connection(
+        *await apns.activate(),
+        courier="1-courier.sandbox.push.apple.com"
+        if sandbox
+        else "1-courier.push.apple.com",
+    ) as connection:
+        token = await connection.mint_scoped_token(topic)
+
+        async with connection.notification_stream(topic, token) as stream:
+            logging.info(
+                f"Listening for notifications on topic {topic} ({'sandbox' if sandbox else 'production'})"
+            )
+            logging.info(f"Token: {token.hex()}")
+
+            async for notification in stream:
+                await connection.ack(notification)
+                logging.info(notification.payload.decode())
 
 
 def main():
@@ -1,6 +1,7 @@
-import frida
 import logging
+
+import frida
 
 
 def attach_to_apsd() -> frida.core.Session:
     frida.kill("apsd")
@@ -2,7 +2,6 @@ import datetime
 import logging
 import ssl
 import tempfile
-from typing import Optional
 
 import anyio
 import anyio.abc
@@ -12,11 +11,10 @@ from cryptography import x509
 from cryptography.hazmat.primitives import serialization
 from cryptography.hazmat.primitives.asymmetric import rsa
 from cryptography.hazmat.primitives.hashes import SHA256
-from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
+from cryptography.hazmat.primitives.serialization import Encoding
 
 # from pypush import apns
-from pypush.apns import transport
-from pypush.apns import protocol
+from pypush.apns import protocol, transport
 
 from . import _frida
 
@@ -71,7 +69,7 @@ async def handle(client: TLSStream):
         else "1-courier.sandbox.push.apple.com"
     )
     name = f"prod-{connection_cnt}" if not sandbox else f"sandbox-{connection_cnt}"
-    async with await transport.create_courier_connection(forward) as conn:
+    async with await transport.create_courier_connection(sandbox, forward) as conn:
         logging.debug("Connected to courier")
         async with anyio.create_task_group() as tg:
             tg.start_soon(forward_packets, client_pkt, conn, f"client-{name}")
tests/assets/dev.jjtech.pypush.tests.pem (new file, +75)
@@ -0,0 +1,75 @@
+Bag Attributes
+    friendlyName: Apple Sandbox Push Services: dev.jjtech.pypush.tests
+    localKeyID: 0A C9 4D 65 F1 39 44 73 5F A8 05 BC B9 00 47 14 2C 12 9A F3
+subject=UID=dev.jjtech.pypush.tests, CN=Apple Sandbox Push Services: dev.jjtech.pypush.tests, OU=C4492JYJR3, C=US
+issuer=CN=Apple Worldwide Developer Relations Certification Authority, OU=G4, O=Apple Inc., C=US
+-----BEGIN CERTIFICATE-----
+MIIGnzCCBYegAwIBAgIQRLQgelpeA0ozi3PDbx2ZmTANBgkqhkiG9w0BAQsFADB1
+MUQwQgYDVQQDDDtBcHBsZSBXb3JsZHdpZGUgRGV2ZWxvcGVyIFJlbGF0aW9ucyBD
+ZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTELMAkGA1UECwwCRzQxEzARBgNVBAoMCkFw
+cGxlIEluYy4xCzAJBgNVBAYTAlVTMB4XDTI0MDUxNjAwMTUwM1oXDTI1MDYxNTAw
+MTUwMlowgYoxJzAlBgoJkiaJk/IsZAEBDBdkZXYuamp0ZWNoLnB5cHVzaC50ZXN0
+czE9MDsGA1UEAww0QXBwbGUgU2FuZGJveCBQdXNoIFNlcnZpY2VzOiBkZXYuamp0
+ZWNoLnB5cHVzaC50ZXN0czETMBEGA1UECwwKQzQ0OTJKWUpSMzELMAkGA1UEBhMC
+VVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQD3BvhGnrBtXpVLVvdi
+HFHYeu58MKBD/vyw3A+a4PXnCXskSdEZDydBXJnKa1OeIqn/7TG5/6iiWGR+pcYa
+XK6kCka8fxpuWgk4/H7C2EN9Atv/XgJit3RSUFdKVN1dvG5cDX5yvFcu7xSt8J+Y
+RHuqM2YGwor1bZNUCi46n144dntB9rEV2ZgLwrHc2ofo/STbdstGKMJHkhg0GVcI
+0IzGderz1Ga1UXB8yhr+CvQthjcm74G+aQJZfuMsGwXI06wbKOJQPtCPdAD0taBW
+rdHivETxRw3WhPzmiwQLUruOmXEo5+bgl1NhnPCLJn374LWaxQEzpnW2HhP6p8mC
+TzZhAgMBAAGjggMTMIIDDzAMBgNVHRMBAf8EAjAAMB8GA1UdIwQYMBaAFFvZ+h3n
+mhoLo5l2IlCGPpHIW3eoMHAGCCsGAQUFBwEBBGQwYjAtBggrBgEFBQcwAoYhaHR0
+cDovL2NlcnRzLmFwcGxlLmNvbS93d2RyZzQuZGVyMDEGCCsGAQUFBzABhiVodHRw
+Oi8vb2NzcC5hcHBsZS5jb20vb2NzcDAzLXd3ZHJnNDAzMIIBHgYDVR0gBIIBFTCC
+AREwggENBgkqhkiG92NkBQEwgf8wgcMGCCsGAQUFBwICMIG2DIGzUmVsaWFuY2Ug
+b24gdGhpcyBjZXJ0aWZpY2F0ZSBieSBhbnkgcGFydHkgYXNzdW1lcyBhY2NlcHRh
+bmNlIG9mIHRoZSB0aGVuIGFwcGxpY2FibGUgc3RhbmRhcmQgdGVybXMgYW5kIGNv
+bmRpdGlvbnMgb2YgdXNlLCBjZXJ0aWZpY2F0ZSBwb2xpY3kgYW5kIGNlcnRpZmlj
+YXRpb24gcHJhY3RpY2Ugc3RhdGVtZW50cy4wNwYIKwYBBQUHAgEWK2h0dHBzOi8v
+d3d3LmFwcGxlLmNvbS9jZXJ0aWZpY2F0ZWF1dGhvcml0eS8wEwYDVR0lBAwwCgYI
+KwYBBQUHAwIwMgYDVR0fBCswKTAnoCWgI4YhaHR0cDovL2NybC5hcHBsZS5jb20v
+d3dkcmc0LTMuY3JsMB0GA1UdDgQWBBQKyU1l8TlEc1+oBby5AEcULBKa8zAOBgNV
+HQ8BAf8EBAMCB4Awgb8GCiqGSIb3Y2QGAwYEgbAwga0MF2Rldi5qanRlY2gucHlw
+dXNoLnRlc3RzMAcMBXRvcGljDBxkZXYuamp0ZWNoLnB5cHVzaC50ZXN0cy52b2lw
+MAYMBHZvaXAMJGRldi5qanRlY2gucHlwdXNoLnRlc3RzLmNvbXBsaWNhdGlvbjAO
+DAxjb21wbGljYXRpb24MIGRldi5qanRlY2gucHlwdXNoLnRlc3RzLnZvaXAtcHR0
+MAsMCS52b2lwLXB0dDAQBgoqhkiG92NkBgMBBAIFADANBgkqhkiG9w0BAQsFAAOC
+AQEAwQac2q1BMnAH1vdZgfDunc+b7SKO6rJIG6w/wl4211YyNBBS5oabQnQDfB8y
+8iOeWnoWXry60gI2fwWN/rRaQn4QCy72jNeTGz/T/s2jwoGj89114JjcBhRAHvQl
+/HN4QjSt5rWVRcxTE4cKKbJIqVCm7Uq9VROgbxXrmsZsRnyk1ASvLGboibtGbmty
+wmXZWns5NXNDbv1wP+PF5HSFXtDWodPYnhvzJe0s9lRvo4yGAt1KL5mNaZM3kKp0
+74kdzKK/iT7954EQK4ZWPQbDnS1A+/BzHQjK0rWTwjDQkbKvNE9bb+KJbNHH3+DX
+5s0ybZYoG5meGKUplwu7A2bfFw==
+-----END CERTIFICATE-----
+Bag Attributes
+    friendlyName: Apple Sandbox Push Services: dev.jjtech.pypush.tests Private Key
+    localKeyID: 0A C9 4D 65 F1 39 44 73 5F A8 05 BC B9 00 47 14 2C 12 9A F3
+Key Attributes: <No Attributes>
+-----BEGIN PRIVATE KEY-----
+MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQD3BvhGnrBtXpVL
+VvdiHFHYeu58MKBD/vyw3A+a4PXnCXskSdEZDydBXJnKa1OeIqn/7TG5/6iiWGR+
+pcYaXK6kCka8fxpuWgk4/H7C2EN9Atv/XgJit3RSUFdKVN1dvG5cDX5yvFcu7xSt
+8J+YRHuqM2YGwor1bZNUCi46n144dntB9rEV2ZgLwrHc2ofo/STbdstGKMJHkhg0
+GVcI0IzGderz1Ga1UXB8yhr+CvQthjcm74G+aQJZfuMsGwXI06wbKOJQPtCPdAD0
+taBWrdHivETxRw3WhPzmiwQLUruOmXEo5+bgl1NhnPCLJn374LWaxQEzpnW2HhP6
+p8mCTzZhAgMBAAECggEBAKADb8eu+3GdFvAagVyYI5wq5Vik1uu0vFKD+cfFeQQT
+bCTxe/TTkAYSybwJEb0Zjy0spE1rgfzHbTFsiIqDBs1TqsZnPuPEhrzXMfVcyTqt
+I3yjlMAFPeAkEqcfmdUiPgp64zHHNmI8lBSoDXlAwypY6PnwArtAI3MItTFcElhX
+gWB44xVGuJRjRP4UVqXg0ML/Ic2yuYT9DRsDRilYhm8RGRSHkdZKdzCicMZcLtC7
+bs6/evmIrk9V5AzF6YiXlfT0dOp6yy9mFwhLljXF3Z2/LdrOTAmhLPQRMbUrJrcW
+ZPd0kMybGIlEoprQEA/6nZkdtIiDo2OJtufCs8g+nJECgYEA/+v4uTJzEI1igKOB
+myJtADECZAsJUaJaKSAM7VHn1hNOKgNLhUHOuroWvIWEhEomWeMvCbZIG42eOwNW
+BXGtG7ruT79E6655dljU6E/029FaxONqXXCTD9ZPh031R293KcydMwgBJJ0pvFJE
+14HWmMRAG0auPygMRhXubXU1ndMCgYEA9xpNWrl9poTjsZDNqvu60nYcq0W1escw
+ovmb87uxZ5u8fC8T1F3AVMYj4v0dTyA4F0mZenY+nri/hJBuanWVxa5Liu0fGnBr
+tEa2rzCMaajoDTNMKSygFz6CIMZbbZhozy0+9DHcRcC6b2UtIgB/+/ZQtrTvQ8Ea
+i6viarkq1nsCgYBznYAM8mynEqhoYvV/RyslBf8FgTLhjU3b/F26rODmhmwucLSi
+a9tf4ge5fTwjo3f17btnUND8mZrdICGxbex9dZKJtmgFbRn0TCdLGCwPTmIKRo7b
+zaqyYeglwSNI9WNJH+X4kuopR1L+f9AX59ExzJ8Fc4XuhEIfO3MuQeBJ/wKBgQDa
+8AgH0X/+EZJ42rcPvxiprxL5wbrpPSHf1M+T5gJqrXcUhNXJ/QMTWbekP+Y/HGn2
+YDTHZ4tWMJUoTJw4YVTBoQu33R8I2wDi6yCkGpzeZVStlXzuomZ6Ed1UUsvhT//V
+SN6VmLP1ba0CVB/oF49OXNDpAWlZm/f8NuBW9Rd6jwKBgQDi495IOjLJ8SvWRJLT
+c9AUmO7IVgipWvr51cF9IYxkzXIVIQIh1usy2NsrBxshAD+FbbWFVBfoptdKBZVK
+J8u+Ou4gTxs8SdGKGZWZpUMEKJbPsq8lE2aU3mBXiWcFRxYpu+n7nKap0Lla/xBD
+v77FY1M3FxGR6rNqPJQ9rRLFbA==
+-----END PRIVATE KEY-----
@@ -1,18 +1,13 @@
-import pytest
-from pypush import apns
-import asyncio
-
-# from aioapns import *
-import uuid
-import anyio
-
-# from pypush.apns import _util
-# from pypush.apns import albert, lifecycle, protocol
-from pypush import apns
-
+import logging
+import uuid
+from pathlib import Path
+
+import httpx
+import pytest
+from rich.logging import RichHandler
+
+from pypush import apns
+
 logging.basicConfig(level=logging.DEBUG, handlers=[RichHandler()], format="%(message)s")
 
 
@@ -26,17 +21,43 @@ async def test_activate():
 
 @pytest.mark.asyncio
 async def test_lifecycle_2():
-    async with apns.create_apns_connection(
-        certificate, key, courier="localhost"
-    ) as connection:
-        await connection.receive(
-            apns.protocol.ConnectAck
-        )  # Just wait until the initial connection is established. Don't do this in real code plz.
+    async with apns.create_apns_connection(certificate, key) as _:
+        pass
+
+
+ASSETS_DIR = Path(__file__).parent / "assets"
+
+
+async def send_test_notification(device_token, payload=b"hello, world"):
+    async with httpx.AsyncClient(
+        cert=str(ASSETS_DIR / "dev.jjtech.pypush.tests.pem"), http2=True
+    ) as client:
+        # Use the certificate and key from above
+        response = await client.post(
+            f"https://api.sandbox.push.apple.com/3/device/{device_token}",
+            content=payload,
+            headers={
+                "apns-topic": "dev.jjtech.pypush.tests",
+                "apns-push-type": "alert",
+                "apns-priority": "10",
+            },
+        )
+        assert response.status_code == 200
 
 
 @pytest.mark.asyncio
-async def test_shorthand():
+async def test_scoped_token():
     async with apns.create_apns_connection(
-        *await apns.activate(), courier="localhost"
+        *await apns.activate(), sandbox=True
     ) as connection:
-        await connection.receive(apns.protocol.ConnectAck)
+        token = await connection.mint_scoped_token("dev.jjtech.pypush.tests")
+
+        test_message = f"test-message-{uuid.uuid4().hex}"
+
+        await send_test_notification(token.hex(), test_message.encode())
+
+        await connection.expect_notification(
+            "dev.jjtech.pypush.tests",
+            token,
+            lambda c: c if c.payload == test_message.encode() else None,
+        )