Skip to content
This repository was archived by the owner on Jun 2, 2019. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 84 additions & 4 deletions telepot/aio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import io
import json
import time
import atexit
import aiohttp
import asyncio
import traceback
import collections
import async_timeout
from concurrent.futures._base import CancelledError
from . import helper, api
from ..api import _methodurl, _fileurl
from .. import (
_BotBase, flavor, _find_first_key, _isstring, _strip, _rectify,
_dismantle_message_identifier, _split_input_media_array
Expand Down Expand Up @@ -46,11 +50,13 @@ def event_now(self, data):
def cancel(self, event):
return event.cancel()

def __init__(self, token, loop=None):
def __init__(self, token, loop=None, session=None, timeout=30):
super(Bot, self).__init__(token)

self._loop = loop or asyncio.get_event_loop()
api._loop = self._loop # sync loop with api module
self.refresh_session(session)
self._timeout = timeout
atexit.register(self.close)

self._scheduler = self.Scheduler(self._loop)

Expand All @@ -59,6 +65,24 @@ def __init__(self, token, loop=None):
'inline_query': helper._create_invoker(self, 'on_inline_query'),
'chosen_inline_result': helper._create_invoker(self, 'on_chosen_inline_result')})

def refresh_session(self, session=None):

if session and not session.closed:
self.session = session
return

self.session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=10),
loop=self._loop
)

def __del__(self):
self.close()

def close(self):
if not self.session.closed:
self.session.close()

@property
def loop(self):
return self._loop
Expand All @@ -74,8 +98,58 @@ def router(self):
async def handle(self, msg):
await self._router.route(msg)

def _compose_timeout(self, req):
token, method, params, files = req

if method == 'getUpdates' and params and 'timeout' in params:
# Ensure HTTP timeout is longer than getUpdates timeout
return params['timeout'] + self._timeout
elif files:
# Disable timeout if uploading files. For some reason, the larger the file,
# the longer it takes for the server to respond (after upload is finished).
# It is unclear how long timeout should be.
return None
else:
return self._timeout

def _transform(self, req, **user_kw):
timeout = self._compose_timeout(req)
data = api._compose_data(req)
url = _methodurl(req, **user_kw)

kwargs = {'data': data}
kwargs.update(user_kw)
if self.session.closed:
self.refresh_session()

return self.session.post, (url,), kwargs, timeout

async def request(self, req, **user_kw):
fn, args, kwargs, timeout = self._transform(req, **user_kw)

if api._proxy:
kwargs['proxy'] = api._proxy[0]
if len(api._proxy) > 1:
kwargs['proxy_auth'] = aiohttp.BasicAuth(*api._proxy[1])

try:
if timeout is None:
async with fn(*args, **kwargs) as r:
return await api._parse(r)
else:
try:
with async_timeout.timeout(timeout):
async with fn(*args, **kwargs) as r:
return await api._parse(r)

except asyncio.TimeoutError:
raise exception.TelegramError('Response timeout', 504, {})

except aiohttp.ClientConnectionError:
raise exception.TelegramError('Connection Error', 400, {})

async def _api_request(self, method, params=None, files=None, **kwargs):
return await api.request((self._token, method, params, files), **kwargs)
return await self.request((self._token, method, params, files), **kwargs)

async def _api_request_with_file(self, method, params, file_key, file_value, **kwargs):
if _isstring(file_value):
Expand Down Expand Up @@ -615,6 +689,12 @@ async def getGameHighScores(self, user_id, game_message_identifier):
p.update(_dismantle_message_identifier(game_message_identifier))
return await self._api_request('getGameHighScores', _rectify(p))

def download(self, req):
session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=1, force_close=True),
loop=self._loop)
return session, session.get(_fileurl(req), timeout=self._timeout)

async def download_file(self, file_id, dest):
"""
Download a file to local disk.
Expand All @@ -626,7 +706,7 @@ async def download_file(self, file_id, dest):
try:
d = dest if isinstance(dest, io.IOBase) else open(dest, 'wb')

session, request = api.download((self._token, f['file_path']))
session, request = self.download((self._token, f['file_path']))

async with session:
async with request as r:
Expand Down
98 changes: 2 additions & 96 deletions telepot/aio/api.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,9 @@
import asyncio
import aiohttp
import async_timeout
import atexit
import re
import json
from .. import exception
from ..api import _methodurl, _which_pool, _fileurl, _guess_filename
from ..api import _guess_filename

_loop = asyncio.get_event_loop()

_pools = {
'default': aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=10),
loop=_loop)
}

_timeout = 30
_proxy = None # (url, (username, password))

def set_proxy(url, basic_auth=None):
Expand All @@ -25,36 +13,7 @@ def set_proxy(url, basic_auth=None):
else:
_proxy = (url, basic_auth) if basic_auth else (url,)

def _close_pools():
global _pools
for s in _pools.values():
s.close()

atexit.register(_close_pools)

def _create_onetime_pool():
return aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=1, force_close=True),
loop=_loop)

def _default_timeout(req, **user_kw):
return _timeout

def _compose_timeout(req, **user_kw):
token, method, params, files = req

if method == 'getUpdates' and params and 'timeout' in params:
# Ensure HTTP timeout is longer than getUpdates timeout
return params['timeout'] + _default_timeout(req, **user_kw)
elif files:
# Disable timeout if uploading files. For some reason, the larger the file,
# the longer it takes for the server to respond (after upload is finished).
# It is unclear how long timeout should be.
return None
else:
return _default_timeout(req, **user_kw)

def _compose_data(req, **user_kw):
def _compose_data(req):
token, method, params, files = req

data = aiohttp.FormData()
Expand All @@ -77,27 +36,6 @@ def _compose_data(req, **user_kw):

return data

def _transform(req, **user_kw):
timeout = _compose_timeout(req, **user_kw)

data = _compose_data(req, **user_kw)

url = _methodurl(req, **user_kw)

name = _which_pool(req, **user_kw)

if name is None:
session = _create_onetime_pool()
cleanup = session.close # one-time session: remember to close
else:
session = _pools[name]
cleanup = None # reuse: do not close

kwargs = {'data':data}
kwargs.update(user_kw)

return session.post, (url,), kwargs, timeout, cleanup

async def _parse(response):
try:
data = await response.json()
Expand All @@ -121,35 +59,3 @@ async def _parse(response):
# ... or raise generic error
raise exception.TelegramError(description, error_code, data)

async def request(req, **user_kw):
fn, args, kwargs, timeout, cleanup = _transform(req, **user_kw)

if _proxy:
kwargs['proxy'] = _proxy[0]
if len(_proxy) > 1:
kwargs['proxy_auth'] = aiohttp.BasicAuth(*_proxy[1])

try:
if timeout is None:
async with fn(*args, **kwargs) as r:
return await _parse(r)
else:
try:
with async_timeout.timeout(timeout):
async with fn(*args, **kwargs) as r:
return await _parse(r)

except asyncio.TimeoutError:
raise exception.TelegramError('Response timeout', 504, {})

except aiohttp.ClientConnectionError:
raise exception.TelegramError('Connection Error', 400, {})

finally:
if cleanup:
cleanup() # e.g. closing one-time session

def download(req):
session = _create_onetime_pool()
return session, session.get(_fileurl(req), timeout=_timeout)
# Caller should close session after download is complete