diff --git a/telepot/aio/__init__.py b/telepot/aio/__init__.py index 17b0b26..f3686ff 100644 --- a/telepot/aio/__init__.py +++ b/telepot/aio/__init__.py @@ -1,11 +1,15 @@ import io import json import time +import atexit +import aiohttp import asyncio import traceback import collections +import async_timeout from concurrent.futures._base import CancelledError from . import helper, api +from ..api import _methodurl, _fileurl from .. import ( _BotBase, flavor, _find_first_key, _isstring, _strip, _rectify, _dismantle_message_identifier, _split_input_media_array @@ -46,11 +50,13 @@ def event_now(self, data): def cancel(self, event): return event.cancel() - def __init__(self, token, loop=None): + def __init__(self, token, loop=None, session=None, timeout=30): super(Bot, self).__init__(token) self._loop = loop or asyncio.get_event_loop() - api._loop = self._loop # sync loop with api module + self.refresh_session(session) + self._timeout = timeout + atexit.register(self.close) self._scheduler = self.Scheduler(self._loop) @@ -59,6 +65,24 @@ def __init__(self, token, loop=None): 'inline_query': helper._create_invoker(self, 'on_inline_query'), 'chosen_inline_result': helper._create_invoker(self, 'on_chosen_inline_result')}) + def refresh_session(self, session=None): + + if session and not session.closed: + self.session = session + return + + self.session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit=10), + loop=self._loop + ) + + def __del__(self): + self.close() + + def close(self): + if not self.session.closed: + self.session.close() + @property def loop(self): return self._loop @@ -74,8 +98,58 @@ def router(self): async def handle(self, msg): await self._router.route(msg) + def _compose_timeout(self, req): + token, method, params, files = req + + if method == 'getUpdates' and params and 'timeout' in params: + # Ensure HTTP timeout is longer than getUpdates timeout + return params['timeout'] + self._timeout + elif files: + # Disable timeout if uploading files. For some reason, the larger the file, + # the longer it takes for the server to respond (after upload is finished). + # It is unclear how long timeout should be. + return None + else: + return self._timeout + + def _transform(self, req, **user_kw): + timeout = self._compose_timeout(req) + data = api._compose_data(req) + url = _methodurl(req, **user_kw) + + kwargs = {'data': data} + kwargs.update(user_kw) + if self.session.closed: + self.refresh_session() + + return self.session.post, (url,), kwargs, timeout + + async def request(self, req, **user_kw): + fn, args, kwargs, timeout = self._transform(req, **user_kw) + + if api._proxy: + kwargs['proxy'] = api._proxy[0] + if len(api._proxy) > 1: + kwargs['proxy_auth'] = aiohttp.BasicAuth(*api._proxy[1]) + + try: + if timeout is None: + async with fn(*args, **kwargs) as r: + return await api._parse(r) + else: + try: + with async_timeout.timeout(timeout): + async with fn(*args, **kwargs) as r: + return await api._parse(r) + + except asyncio.TimeoutError: + raise exception.TelegramError('Response timeout', 504, {}) + + except aiohttp.ClientConnectionError: + raise exception.TelegramError('Connection Error', 400, {}) + async def _api_request(self, method, params=None, files=None, **kwargs): - return await api.request((self._token, method, params, files), **kwargs) + return await self.request((self._token, method, params, files), **kwargs) async def _api_request_with_file(self, method, params, file_key, file_value, **kwargs): if _isstring(file_value): @@ -615,6 +689,12 @@ async def getGameHighScores(self, user_id, game_message_identifier): p.update(_dismantle_message_identifier(game_message_identifier)) return await self._api_request('getGameHighScores', _rectify(p)) + def download(self, req): + session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit=1, force_close=True), + loop=self._loop) + return session, session.get(_fileurl(req), timeout=self._timeout) + async def download_file(self, file_id, dest): """ Download a file to local disk. @@ -626,7 +706,7 @@ async def download_file(self, file_id, dest): try: d = dest if isinstance(dest, io.IOBase) else open(dest, 'wb') - session, request = api.download((self._token, f['file_path'])) + session, request = self.download((self._token, f['file_path'])) async with session: async with request as r: diff --git a/telepot/aio/api.py b/telepot/aio/api.py index a724461..7baa530 100644 --- a/telepot/aio/api.py +++ b/telepot/aio/api.py @@ -1,21 +1,9 @@ -import asyncio import aiohttp -import async_timeout -import atexit import re import json from .. import exception -from ..api import _methodurl, _which_pool, _fileurl, _guess_filename +from ..api import _guess_filename -_loop = asyncio.get_event_loop() - -_pools = { - 'default': aiohttp.ClientSession( - connector=aiohttp.TCPConnector(limit=10), - loop=_loop) -} - -_timeout = 30 _proxy = None # (url, (username, password)) def set_proxy(url, basic_auth=None): @@ -25,36 +13,7 @@ def set_proxy(url, basic_auth=None): else: _proxy = (url, basic_auth) if basic_auth else (url,) -def _close_pools(): - global _pools - for s in _pools.values(): - s.close() - -atexit.register(_close_pools) - -def _create_onetime_pool(): - return aiohttp.ClientSession( - connector=aiohttp.TCPConnector(limit=1, force_close=True), - loop=_loop) - -def _default_timeout(req, **user_kw): - return _timeout - -def _compose_timeout(req, **user_kw): - token, method, params, files = req - - if method == 'getUpdates' and params and 'timeout' in params: - # Ensure HTTP timeout is longer than getUpdates timeout - return params['timeout'] + _default_timeout(req, **user_kw) - elif files: - # Disable timeout if uploading files. For some reason, the larger the file, - # the longer it takes for the server to respond (after upload is finished). - # It is unclear how long timeout should be. - return None - else: - return _default_timeout(req, **user_kw) - -def _compose_data(req, **user_kw): +def _compose_data(req): token, method, params, files = req data = aiohttp.FormData() @@ -77,27 +36,6 @@ def _compose_data(req, **user_kw): return data -def _transform(req, **user_kw): - timeout = _compose_timeout(req, **user_kw) - - data = _compose_data(req, **user_kw) - - url = _methodurl(req, **user_kw) - - name = _which_pool(req, **user_kw) - - if name is None: - session = _create_onetime_pool() - cleanup = session.close # one-time session: remember to close - else: - session = _pools[name] - cleanup = None # reuse: do not close - - kwargs = {'data':data} - kwargs.update(user_kw) - - return session.post, (url,), kwargs, timeout, cleanup - async def _parse(response): try: data = await response.json() @@ -121,35 +59,3 @@ async def _parse(response): # ... or raise generic error raise exception.TelegramError(description, error_code, data) -async def request(req, **user_kw): - fn, args, kwargs, timeout, cleanup = _transform(req, **user_kw) - - if _proxy: - kwargs['proxy'] = _proxy[0] - if len(_proxy) > 1: - kwargs['proxy_auth'] = aiohttp.BasicAuth(*_proxy[1]) - - try: - if timeout is None: - async with fn(*args, **kwargs) as r: - return await _parse(r) - else: - try: - with async_timeout.timeout(timeout): - async with fn(*args, **kwargs) as r: - return await _parse(r) - - except asyncio.TimeoutError: - raise exception.TelegramError('Response timeout', 504, {}) - - except aiohttp.ClientConnectionError: - raise exception.TelegramError('Connection Error', 400, {}) - - finally: - if cleanup: - cleanup() # e.g. closing one-time session - -def download(req): - session = _create_onetime_pool() - return session, session.get(_fileurl(req), timeout=_timeout) - # Caller should close session after download is complete