apns: filters: move FilteredStream

This commit is contained in:
JJTech0130 2024-05-23 08:29:51 -04:00
parent c22904a39d
commit 51b04f816d
No known key found for this signature in database
GPG Key ID: 23C92EBCCF8F93D6
3 changed files with 45 additions and 39 deletions

View File

@ -3,9 +3,7 @@ from contextlib import asynccontextmanager
from typing import Generic, TypeVar
import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from . import filters
from anyio.abc import ObjectSendStream
T = TypeVar("T")
@ -44,31 +42,6 @@ class BroadcastStream(Generic[T]):
await send.aclose()
W = TypeVar("W")
F = TypeVar("F")
class FilteredStream(ObjectReceiveStream[F]):
    """
    A stream that filters out unwanted items

    filter should return None if the item should be filtered out, otherwise it should return the item or a modified version of it
    """

    def __init__(self, source: ObjectReceiveStream[W], filter: filters.Filter[W, F]):
        # Upstream stream producing raw (unfiltered) items.
        self.source = source
        # Callable mapping W -> F | None; None means "drop this item".
        self.filter = filter

    async def receive(self) -> F:
        # Keep pulling from the source until the filter accepts an item
        # (returns non-None); that accepted/transformed value is delivered.
        async for item in self.source:
            if (filtered := self.filter(item)) is not None:
                return filtered
        # Source exhausted without yielding an acceptable item.
        raise anyio.EndOfStream

    async def aclose(self):
        # Closing the filtered view also closes the wrapped source stream.
        await self.source.aclose()
def exponential_backoff(f):
async def wrapper(*args, **kwargs):
backoff = 1

View File

@ -1,6 +1,9 @@
import logging
from typing import Callable, Optional, Type, TypeVar
import anyio
from anyio.abc import ObjectReceiveStream
from pypush.apns import protocol
T1 = TypeVar("T1")
@ -42,3 +45,28 @@ def ALL(c):
def NONE(_):
return None
W = TypeVar("W")
F = TypeVar("F")
class FilteredStream(ObjectReceiveStream[F]):
    """
    A receive stream that applies a filter to each item from a source stream.

    The filter returns None to drop an item; any non-None result (the item
    itself, or a transformed version of it) is what callers receive.
    """

    def __init__(self, source: ObjectReceiveStream[W], filter: Filter[W, F]):
        # Upstream stream to pull raw items from.
        self.source = source
        # Callable mapping W -> F | None; None means "skip this item".
        self.filter = filter

    async def receive(self) -> F:
        # Pull from the underlying stream until the filter accepts something.
        # When the source is exhausted, its EndOfStream propagates to the
        # caller, matching the plain-stream contract.
        while True:
            candidate = await self.source.receive()
            accepted = self.filter(candidate)
            if accepted is not None:
                return accepted

    async def aclose(self):
        # Closing the filtered view closes the wrapped source stream too.
        await self.source.aclose()

View File

@ -99,7 +99,9 @@ class Connection:
@_util.exponential_backoff
async def reconnect(self):
async with self._reconnect_lock: # Prevent weird situations where multiple reconnects are happening at once
async with (
self._reconnect_lock
): # Prevent weird situations where multiple reconnects are happening at once
if self._conn is not None:
logging.warning("Closing existing connection")
await self._conn.aclose()
@ -172,7 +174,7 @@ class Connection:
backlog: bool = True,
):
async with self._broadcast.open_stream(backlog) as stream:
yield _util.FilteredStream(stream, filter)
yield filters.FilteredStream(stream, filter)
async def _receive(
self, filter: filters.Filter[protocol.Command, T], backlog: bool = True
@ -234,18 +236,21 @@ class Connection:
):
if token is None:
token = await self.base_token
async with self._filter([topic]), self._receive_stream(
filters.chain(
async with (
self._filter([topic]),
self._receive_stream(
filters.chain(
filters.chain(
filters.cmd(protocol.SendMessageCommand),
lambda c: c if c.token == token else None,
filters.chain(
filters.cmd(protocol.SendMessageCommand),
lambda c: c if c.token == token else None,
),
lambda c: (c if c.topic == topic else None),
),
lambda c: (c if c.topic == topic else None),
),
filter,
)
) as stream:
filter,
)
) as stream,
):
yield stream
async def ack(self, command: protocol.SendMessageCommand, status: int = 0):