1+ import asyncio
2+ import rasa
3+ import logging
4+ import json
5+ import inspect
6+ from rasa .core .channels .channel import InputChannel , OutputChannel , UserMessage , QueueOutputChannel
7+ from sanic .request import Request
8+ from sanic import Sanic , Blueprint , response
9+ from asyncio import Queue , CancelledError
10+ from typing import Text , List , Dict , Any , Optional , Callable , Iterable , Awaitable
11+ from rasa .core import utils
12+
13+ logger = logging .getLogger (__name__ )
14+
15+ class RestOutput (OutputChannel ):
16+ def __init__ (self ):
17+ self .messages = []
18+ self .custom_data = None
19+ self .language = None
20+
21+ def set_custom_data (self , custom_data ):
22+ self .custom_data = custom_data
23+
24+ def set_language (self , language ):
25+ self .language = language
26+
27+ @classmethod
28+ def name (cls ):
29+ return "rest"
30+
31+ @staticmethod
32+ def _message (
33+ recipient_id , text = None , image = None , buttons = None , attachment = None , custom = None
34+ ):
35+ """Create a message object that will be stored."""
36+
37+ obj = {
38+ "recipient_id" : recipient_id ,
39+ "text" : text ,
40+ "image" : image ,
41+ "buttons" : buttons ,
42+ "attachment" : attachment ,
43+ "custom" : custom ,
44+ }
45+
46+ # filter out any values that are `None`
47+ return utils .remove_none_values (obj )
48+
49+ def latest_output (self ):
50+ if self .messages :
51+ return self .messages [- 1 ]
52+ else :
53+ return None
54+
55+ async def _persist_message (self , message ) -> None :
56+ self .messages .append (message ) # pytype: disable=bad-return-type
57+
58+ async def send_text_message (
59+ self , recipient_id : Text , text : Text , ** kwargs : Any
60+ ) -> None :
61+ for message_part in text .split ("\n \n " ):
62+ await self ._persist_message (self ._message (recipient_id , text = message_part ))
63+
64+ async def send_image_url (
65+ self , recipient_id : Text , image : Text , ** kwargs : Any
66+ ) -> None :
67+ """Sends an image. Default will just post the url as a string."""
68+
69+ await self ._persist_message (self ._message (recipient_id , image = image ))
70+
71+ async def send_attachment (
72+ self , recipient_id : Text , attachment : Text , ** kwargs : Any
73+ ) -> None :
74+ """Sends an attachment. Default will just post as a string."""
75+
76+ await self ._persist_message (self ._message (recipient_id , attachment = attachment ))
77+
78+ async def send_text_with_buttons (
79+ self ,
80+ recipient_id : Text ,
81+ text : Text ,
82+ buttons : List [Dict [Text , Any ]],
83+ ** kwargs : Any
84+ ) -> None :
85+ await self ._persist_message (
86+ self ._message (recipient_id , text = text , buttons = buttons )
87+ )
88+
89+ async def send_custom_json (
90+ self , recipient_id : Text , json_message : Dict [Text , Any ], ** kwargs : Any
91+ ) -> None :
92+ await self ._persist_message (self ._message (recipient_id , custom = json_message ))
93+
94+ class RestInput (InputChannel ):
95+ @classmethod
96+ def name (cls ):
97+ return "rest"
98+
99+ @staticmethod
100+ async def on_message_wrapper (
101+ on_new_message : Callable [[UserMessage ], Awaitable [None ]],
102+ text : Text ,
103+ queue : Queue ,
104+ sender_id : Text ,
105+ input_channel : Text ,
106+ ) -> None :
107+ collector = QueueOutputChannel (queue )
108+
109+ message = UserMessage (text , collector , sender_id , input_channel = input_channel )
110+ await on_new_message (message )
111+
112+ await queue .put ("DONE" ) # pytype: disable=bad-return-type
113+
114+ async def _extract_sender (self , req : Request ) -> Optional [Text ]:
115+ return req .json .get ("sender" , None )
116+
117+ # noinspection PyMethodMayBeStatic
118+ def _extract_message (self , req : Request ) -> Optional [Text ]:
119+ return req .json .get ("message" , None )
120+
121+ def _extract_input_channel (self , req : Request ) -> Text :
122+ return req .json .get ("input_channel" ) or self .name ()
123+
124+ def _extract_custom_data (self , req : Request ) -> Text :
125+ return req .json .get ("customData" , None )
126+
127+ def stream_response (
128+ self ,
129+ on_new_message : Callable [[UserMessage ], Awaitable [None ]],
130+ text : Text ,
131+ sender_id : Text ,
132+ input_channel : Text ,
133+ ) -> Callable [[Any ], Awaitable [None ]]:
134+ async def stream (resp : Any ) -> None :
135+ q = Queue ()
136+ task = asyncio .ensure_future (
137+ self .on_message_wrapper (
138+ on_new_message , text , q , sender_id , input_channel
139+ )
140+ )
141+ result = None # declare variable up front to avoid pytype error
142+ while True :
143+ result = await q .get ()
144+ if result == "DONE" :
145+ break
146+ else :
147+ await resp .write (json .dumps (result ) + "\n " )
148+ await task
149+
150+ return stream # pytype: disable=bad-return-type
151+
152+ def blueprint (self , on_new_message : Callable [[UserMessage ], Awaitable [None ]]):
153+ custom_webhook = Blueprint (
154+ "custom_webhook_{}" .format (type (self ).__name__ ),
155+ inspect .getmodule (self ).__name__ ,
156+ )
157+
158+ # noinspection PyUnusedLocal
159+ @custom_webhook .route ("/" , methods = ["GET" ])
160+ async def health (request : Request ):
161+ return response .json ({"status" : "ok" })
162+
163+ @custom_webhook .route ("/" , methods = ["POST" ])
164+ async def receive (request : Request ):
165+ sender_id = await self ._extract_sender (request )
166+ text = self ._extract_message (request )
167+ custom_data = self ._extract_custom_data (request )
168+ should_use_stream = rasa .utils .endpoints .bool_arg (
169+ request , "stream" , default = False
170+ )
171+ input_channel = self ._extract_input_channel (request )
172+
173+ output_channel = RestOutput ()
174+ if custom_data :
175+ output_channel .set_custom_data (custom_data )
176+ if "language" in custom_data :
177+ output_channel .set_language (custom_data ["language" ])
178+ if should_use_stream :
179+ return response .stream (
180+ self .stream_response (
181+ on_new_message , text , sender_id , input_channel
182+ ),
183+ content_type = "text/event-stream" ,
184+ )
185+ else :
186+ # noinspection PyBroadException
187+ try :
188+ await on_new_message (
189+ UserMessage (
190+ text , output_channel , sender_id , input_channel = input_channel
191+ )
192+ )
193+ except CancelledError :
194+ logger .error (
195+ "Message handling timed out for "
196+ "user message '{}'." .format (text )
197+ )
198+ except Exception :
199+ logger .exception (
200+ "An exception occured while handling "
201+ "user message '{}'." .format (text )
202+ )
203+ return response .json (output_channel .messages )
204+
205+ return custom_webhook
0 commit comments