"""Build the translated German Alpaca dataset and upload it for review in Argilla.

This is the code of the ``Create-Argilla-Dataset.ipynb`` notebook; every
``# %%`` marker below starts what was one notebook cell.

Steps:

1. Load the records from ``./translated_german_alpaca.json``.
2. Give every record a random UUID and a metadata dict holding the translation
   model name and the record's row position.
3. Embed the ``instruction``, ``input`` and ``output`` columns with a
   sentence-transformers model.
4. Convert the rows into Argilla text-classification records.
5. Push the records to the Hugging Face Hub and log them to an Argilla server,
   together with the label schema annotators should use.

NOTE(review): ``rg.init`` / ``rg.log`` / ``rg.DatasetForTextClassification``
look like the Argilla 1.x client API - confirm the installed ``argilla``
version matches, since the install line below is unpinned.
"""

# %%
# Notebook shell command, kept as a comment so this file parses as plain Python:
# !pip3 install datasets argilla sentence-transformers

# %%
import os
import uuid

import argilla as rg
import pandas as pd
import torch  # NOTE(review): expected to be installed as a sentence-transformers dependency

from datasets import Dataset, load_dataset
from numpy import load
from sentence_transformers import SentenceTransformer

# Name used for the dataset on the Argilla server.
DATASET_NAME = "translated_german_alpaca"

# %%
dataset = pd.read_json("./translated_german_alpaca.json")

# One random UUID string per row.
dataset["id"] = [str(uuid.uuid4()) for _ in range(len(dataset))]
# ``original_id`` is the 0-based row position in the loaded frame.
dataset["metadata"] = [
    {"translation_model": "facebook/wmt19-en-de", "original_id": row_index}
    for row_index in range(len(dataset))
]

ds = Dataset.from_pandas(dataset)
ds[100]  # notebook display of one sample row; requires at least 101 rows

# %%
sbert_model = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

# Use the first GPU when CUDA is available and fall back to the CPU otherwise,
# instead of failing on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
encoder = SentenceTransformer(sbert_model, device=device)


def embed_batch(batch):
    """Encode the three text columns of a batch into three new vector columns."""
    return {
        "vector_instruction": encoder.encode(batch["instruction"]),
        "vector_input": encoder.encode(batch["input"]),
        "vector_output": encoder.encode(batch["output"]),
    }


ds = ds.map(embed_batch, batched=True, batch_size=32)

# %%


def collect_vectors(row):
    """Group the three per-field embeddings of one row under a ``vectors`` key."""
    return {
        "vectors": {
            "instruction": row["vector_instruction"],
            "input": row["vector_input"],
            "output": row["vector_output"],
        }
    }


# create vector dict with three embedded fields, as expected by argilla data model
ds = ds.map(collect_vectors)

# %%
# Only ``instruction`` actually changes name; the other two entries map a
# column to itself.
# NOTE(review): presumably the leading underscore makes the instruction sort
# first among the record inputs - confirm against the Argilla UI.
ds = ds.rename_columns(
    {"instruction": "_instruction", "input": "input", "output": "output"}
)
records = rg.DatasetForTextClassification.from_datasets(
    ds, inputs=["_instruction", "input", "output"]
)

# %%
labels = [
    "BAD INSTRUCTION",
    "BAD INPUT",
    "BAD OUTPUT",
    "INAPPROPRIATE",
    "BIASED",
    "ALL GOOD",
]

settings = rg.TextClassificationSettings(label_schema=labels)

# %%
records.to_datasets().push_to_hub("LEL-A/translated_german_alpaca")

# %%
rg.init(
    # Read the key from the environment so no secret is stored in the notebook;
    # an unset variable yields the same empty string the notebook used before.
    api_key=os.environ.get("ARGILLA_API_KEY", ""),
    api_url="https://lel-a-german-alpaca-test.hf.space",
)
# Apply the label schema defined above; without this call ``settings`` was
# built but never sent to the server.
rg.configure_dataset(name=DATASET_NAME, settings=settings)
rg.log(records=records, name=DATASET_NAME, batch_size=100)