diff --git a/api/gpt.py b/api/gpt.py
index 88d5f6adc..73524eee0 100644
--- a/api/gpt.py
+++ b/api/gpt.py
@@ -50,8 +50,11 @@ def __init__(self,
                  input_suffix="\n",
                  output_prefix="output: ",
                  output_suffix="\n\n",
-                 append_output_prefix_to_query=False):
+                 append_output_prefix_to_query=False,
+                 premise_prefix="",
+                 premise_suffix="\n\n"):
         self.examples = {}
+        self.premise = ""
         self.engine = engine
         self.temperature = temperature
         self.max_tokens = max_tokens
@@ -60,6 +63,8 @@ def __init__(self,
         self.output_prefix = output_prefix
         self.output_suffix = output_suffix
         self.append_output_prefix_to_query = append_output_prefix_to_query
+        self.premise_prefix = premise_prefix
+        self.premise_suffix = premise_suffix
         self.stop = (output_suffix + input_prefix).strip()
 
     def add_example(self, ex):
@@ -70,6 +75,10 @@ def add_example(self, ex):
         assert isinstance(ex, Example), "Please create an Example object."
         self.examples[ex.get_id()] = ex
 
+    def set_premise(self, premise):
+        """Sets the premise that will be prepended to the query."""
+        self.premise = premise
+
     def delete_example(self, id):
         """Delete example with the specific id."""
         if id in self.examples:
@@ -102,7 +111,11 @@ def get_max_tokens(self):
 
     def craft_query(self, prompt):
         """Creates the query for the API request."""
-        q = self.get_prime_text(
+        if self.premise:
+            q = self.premise_prefix + self.premise + self.premise_suffix
+        else:
+            q = ""
+        q = q + self.get_prime_text(
         ) + self.input_prefix + prompt + self.input_suffix
         if self.append_output_prefix_to_query:
             q = q + self.output_prefix
diff --git a/examples/run_twitter_fiction_app.py b/examples/run_twitter_fiction_app.py
new file mode 100644
index 000000000..b5840c4e2
--- /dev/null
+++ b/examples/run_twitter_fiction_app.py
@@ -0,0 +1,40 @@
+import json
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
+
+import requests
+
+from api import GPT, Example, UIConfig
+from api import demo_web_app
+
+PROMPT_EXAMPLE_URL = "https://raw.githubusercontent.com/ml4j/gpt-scrolls/master/tweets/twitter-fiction-prompt.json"
+TEMPLATE_EXAMPLE_URL = "https://raw.githubusercontent.com/ml4j/gtp-3-prompt-templates/master/question-answer/default/templates/question_answer_template_2.json"
+
+prompt_example_json = json.loads(requests.get(PROMPT_EXAMPLE_URL).text)
+template_json = json.loads(requests.get(TEMPLATE_EXAMPLE_URL).text)
+
+# Construct GPT object and show some examples
+gpt = GPT(engine="davinci",
+          temperature=1.1,
+          max_tokens=100,
+          input_prefix=template_json['questionPrefix'],
+          input_suffix=template_json['questionSuffix'],
+          output_prefix=template_json['answerPrefix'],
+          output_suffix=template_json['answerSuffix'],
+          append_output_prefix_to_query=False,
+          premise_prefix=template_json['premisePrefix'],
+          premise_suffix=template_json['premiseSuffix'])
+
+gpt.set_premise(prompt_example_json['premise'])
+
+for example in prompt_example_json['questionsAndAnswers']:
+    gpt.add_example(Example(example['question'], example['answer']))
+
+# Define UI configuration
+config = UIConfig(description="Twitter Fiction",
+                  button_text="Generate",
+                  placeholder=prompt_example_json['defaultPromptQuestion'])
+
+
+demo_web_app(gpt, config)
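
A minimal usage sketch, not part of the patch, showing what the new premise support does to the crafted query. It assumes the `api` package from this repo with the patch above applied; the prefix/suffix strings, the premise text, and the example pair are illustrative values rather than ones loaded from a template.

# Sketch (assumes this patch is applied): set_premise() plus the new
# premise_prefix/premise_suffix parameters prepend a premise block to the
# string returned by craft_query().
from api import GPT, Example

gpt = GPT(engine="davinci",
          temperature=0.7,
          max_tokens=100,
          input_prefix="Q: ",
          input_suffix="\n",
          output_prefix="A: ",
          output_suffix="\n\n",
          premise_prefix="Premise: ",
          premise_suffix="\n\n")

gpt.set_premise("A bot that continues short works of fiction.")
gpt.add_example(Example("Once upon a time...",
                        "...there was a tweet that never ended."))

print(gpt.craft_query("It was a dark and stormy night."))
# With the patch applied this prints:
#   Premise: A bot that continues short works of fiction.
#
#   Q: Once upon a time...
#   A: ...there was a tweet that never ended.
#
#   Q: It was a dark and stormy night.

The example app above works the same way, except that the prefixes and suffixes come from the template JSON and the premise text and question/answer pairs come from the prompt JSON, so the framing and the content stay decoupled.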