Skip to content

Commit e96c927

Browse files
committed
feat: add Qwen3 SFT test, remove AI21 and add Claude APIs
- add Qwen3 SFT test - remove AI21 API - add Claude API - update README.md
1 parent 100f473 commit e96c927

File tree

12 files changed

+503
-112
lines changed

12 files changed

+503
-112
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ Replace `<model_path>` with a local directory or a Hugging Face model like `face
301301
- [x] Dataset generation using self-instruction
302302
- [x] Low-precision LoRA fine-tuning and unsupervised fine-tuning
303303
- [x] INT8 low-precision fine-tuning support
304-
- [x] OpenAI, Cohere and AI21 Studio model APIs for dataset generation
304+
- [x] OpenAI, Cohere, and Claude model APIs for dataset generation
305305
- [x] Added fine-tuned checkpoints for some models to the hub
306306
- [x] INT4 LLaMA LoRA fine-tuning demo
307307
- [x] INT4 LLaMA LoRA fine-tuning with INT4 generation

docs/docs/advanced/generate.md

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,23 @@ engine = Davinci("your-api-key")
2626

2727
</TabItem>
2828

29-
<TabItem value="cohere" label="Cohere">
30-
31-
```python
32-
from xturing.model_apis.cohere import Medium
33-
engine = Medium("your-api-key")
34-
```
35-
36-
</TabItem>
37-
<TabItem value="ai21" label="AI21">
38-
39-
```python
40-
from xturing.model_apis.ai21 import J2Grande
41-
engine = J2Grande("your-api-key")
42-
```
43-
44-
</TabItem>
45-
</Tabs>
29+
<TabItem value="cohere" label="Cohere">
30+
31+
```python
32+
from xturing.model_apis.cohere import Medium
33+
engine = Medium("your-api-key")
34+
```
35+
36+
</TabItem>
37+
<TabItem value="claude" label="Claude">
38+
39+
```python
40+
from xturing.model_apis.claude import ClaudeSonnet
41+
engine = ClaudeSonnet("your-api-key")
42+
```
43+
44+
</TabItem>
45+
</Tabs>
4646

4747
## From no data
4848

examples/datasets/create_alpaca_dataset.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@
4242
"#\n",
4343
"# engine = Medium(\"your-api-key\")\n",
4444
"\n",
45-
"# Alternatively, you can use AI21 to generate dataset\n",
45+
"# Alternatively, you can use Claude to generate dataset\n",
4646
"\n",
47-
"# from xturing.model_apis.ai21 import J2Grande\n",
47+
"# from xturing.model_apis.claude import ClaudeSonnet\n",
4848
"#\n",
4949
"# engine = ClaudeSonnet(\"your-api-key\")"
5050
],

examples/datasets/create_instruction_dataset_from_files.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@
4646
"#\n",
4747
"# engine = Medium(\"your-api-key\")\n",
4848
"\n",
49-
"# Alternatively, you can use AI21 to generate dataset\n",
49+
"# Alternatively, you can use Claude to generate dataset\n",
5050
"\n",
51-
"# from xturing.model_apis.ai21 import J2Grande\n",
51+
"# from xturing.model_apis.claude import ClaudeSonnet\n",
5252
"#\n",
5353
"# engine = ClaudeSonnet(\"your-api-key\")"
5454
]

examples/features/dataset_generation/create_alpaca_dataset.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@
4242
"#\n",
4343
"# engine = Medium(\"your-api-key\")\n",
4444
"\n",
45-
"# Alternatively, you can use AI21 to generate dataset\n",
45+
"# Alternatively, you can use Claude to generate dataset\n",
4646
"\n",
47-
"# from xturing.model_apis.ai21 import J2Grande\n",
47+
"# from xturing.model_apis.claude import ClaudeSonnet\n",
4848
"#\n",
4949
"# engine = ClaudeSonnet(\"your-api-key\")"
5050
],

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ keywords = [
4343
dependencies = [
4444
"torch >= 1.9.0",
4545
"pytorch-lightning",
46-
"transformers>=4.53.0",
46+
"transformers>=4.36.0",
4747
"datasets==2.14.5",
4848
"pyarrow >= 8.0.0, < 21.0.0",
4949
"scipy >= 1.0.0",
@@ -54,8 +54,8 @@ dependencies = [
5454
"gradio>=5.31.0",
5555
"click",
5656
"wget",
57-
"ai21",
5857
"cohere",
58+
"anthropic",
5959
"ipywidgets",
6060
"openai >= 0.27.0",
6161
"pydantic >= 1.10.0",

src/xturing/model_apis/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from xturing.model_apis.ai21 import AI21TextGenerationAPI
2-
from xturing.model_apis.ai21 import J2Grande as AI21J2Grande
31
from xturing.model_apis.base import BaseApi, TextGenerationAPI
2+
from xturing.model_apis.claude import ClaudeSonnet, ClaudeTextGenerationAPI
43
from xturing.model_apis.cohere import CohereTextGenerationAPI
54
from xturing.model_apis.cohere import Medium as CohereMedium
65
from xturing.model_apis.openai import ChatGPT as OpenAIChatGPT
@@ -9,8 +8,8 @@
98

109
BaseApi.add_to_registry(OpenAITextGenerationAPI.config_name, OpenAITextGenerationAPI)
1110
BaseApi.add_to_registry(CohereTextGenerationAPI.config_name, CohereTextGenerationAPI)
12-
BaseApi.add_to_registry(AI21TextGenerationAPI.config_name, AI21TextGenerationAPI)
11+
BaseApi.add_to_registry(ClaudeTextGenerationAPI.config_name, ClaudeTextGenerationAPI)
1312
BaseApi.add_to_registry(OpenAIDavinci.config_name, OpenAIDavinci)
1413
BaseApi.add_to_registry(OpenAIChatGPT.config_name, OpenAIChatGPT)
1514
BaseApi.add_to_registry(CohereMedium.config_name, CohereMedium)
16-
BaseApi.add_to_registry(AI21J2Grande.config_name, AI21J2Grande)
15+
BaseApi.add_to_registry(ClaudeSonnet.config_name, ClaudeSonnet)

src/xturing/model_apis/ai21.py

Lines changed: 0 additions & 70 deletions
This file was deleted.

src/xturing/model_apis/claude.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import time
2+
from datetime import datetime
3+
4+
try:
5+
from anthropic import Anthropic
6+
from anthropic import APIConnectionError as AnthropicAPIConnectionError
7+
from anthropic import APIError as AnthropicAPIError
8+
from anthropic import RateLimitError as AnthropicRateLimitError
9+
except ModuleNotFoundError as import_err: # pragma: no cover - optional dependency
10+
Anthropic = None
11+
AnthropicAPIError = (
12+
AnthropicAPIConnectionError
13+
) = AnthropicRateLimitError = Exception
14+
_ANTHROPIC_IMPORT_ERROR = import_err
15+
else: # pragma: no cover - dependency import paths exercised in runtime envs
16+
_ANTHROPIC_IMPORT_ERROR = None
17+
18+
from xturing.model_apis.base import TextGenerationAPI
19+
20+
21+
class ClaudeTextGenerationAPI(TextGenerationAPI):
    """Text-generation backend that sends prompts to Anthropic's Messages API.

    Each prompt is sent as a single ``user`` message and the reply is
    reshaped into a ``{"choices": [{"text": ..., "finish_reason": ...}]}``
    dictionary.
    """

    config_name = "claude"

    def __init__(self, model, api_key, request_batch_size=1):
        """Create a client bound to ``model``.

        Args:
            model: Anthropic model identifier; forwarded to the base class
                as ``engine``.
            api_key: Anthropic API key.
            request_batch_size: Forwarded to the base class unchanged.

        Raises:
            ModuleNotFoundError: If the ``anthropic`` package is not installed.
        """
        # Fail early with an actionable message before touching the SDK.
        self._ensure_dependency()
        super().__init__(
            engine=model, api_key=api_key, request_batch_size=request_batch_size
        )
        self._client = Anthropic(api_key=api_key)

    @staticmethod
    def _ensure_dependency():
        """Raise ``ModuleNotFoundError`` if the anthropic SDK failed to import."""
        if Anthropic is None:
            message = (
                "The anthropic SDK is required for ClaudeTextGenerationAPI. "
                "Install it with `pip install anthropic`."
            )
            raise ModuleNotFoundError(message) from _ANTHROPIC_IMPORT_ERROR

    def _make_request(self, prompt, max_tokens, temperature, top_p, stop_sequences):
        """Issue one ``messages.create`` call for a single prompt.

        ``top_p`` and ``stop_sequences`` are only included in the request
        when they carry a value, so the API's own defaults apply otherwise.
        """
        params = {
            # NOTE(review): self.engine is presumably set by
            # TextGenerationAPI.__init__ from the ``engine`` argument.
            "model": self.engine,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "messages": [{"role": "user", "content": prompt}],
        }
        if top_p is not None:
            params["top_p"] = top_p
        if stop_sequences:
            params["stop_sequences"] = stop_sequences
        return self._client.messages.create(**params)

    @staticmethod
    def _render_response(response):
        """Convert an SDK response into the ``choices`` dictionary.

        Returns ``None`` when ``response`` is ``None`` (i.e. every attempt
        failed). Only content blocks whose ``type`` is ``"text"`` contribute
        to the output; their texts are concatenated in order.
        """
        if response is None:
            return None
        # ``or []`` also covers a response whose ``content`` attribute exists
        # but is None, which would otherwise raise TypeError when iterated.
        blocks = getattr(response, "content", None) or []
        text_chunks = [
            getattr(block, "text", "")
            for block in blocks
            if getattr(block, "type", None) == "text"
        ]
        return {
            "choices": [
                {
                    "text": "".join(text_chunks),
                    "finish_reason": getattr(response, "stop_reason", "eos"),
                }
            ]
        }

    def _request_with_retries(
        self, prompt, max_tokens, temperature, top_p, stop_sequences, retries
    ):
        """Call ``_make_request`` up to ``retries + 1`` times with backoff.

        Returns the SDK response, or ``None`` once all attempts have failed.
        """
        retry_cnt = 0
        backoff_time = 30
        while retry_cnt <= retries:
            try:
                return self._make_request(
                    prompt=prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    stop_sequences=stop_sequences,
                )
            except (
                AnthropicAPIError,
                AnthropicAPIConnectionError,
                AnthropicRateLimitError,
            ) as e:
                print(f"ClaudeError: {e}.")
                retry_cnt += 1
                # Only back off when another attempt will actually follow;
                # sleeping after the final failure just delays the caller.
                if retry_cnt <= retries:
                    print(f"Retrying in {backoff_time} seconds...")
                    time.sleep(backoff_time)
                    backoff_time *= 1.5
        return None

    def generate_text(
        self,
        prompts,
        max_tokens,
        temperature,
        top_p=None,
        frequency_penalty=None,
        presence_penalty=None,
        stop_sequences=None,
        logprobs=None,
        n=1,
        best_of=1,
        retries=3,
        **kwargs,
    ):
        """Generate one completion per prompt, sequentially.

        Args:
            prompts: A single prompt or a list of prompts.
            max_tokens: Forwarded to the API as ``max_tokens``.
            temperature: Forwarded to the API as ``temperature``.
            top_p: Forwarded when not ``None``.
            stop_sequences: Forwarded when non-empty.
            retries: Number of additional attempts after the first failure.
            frequency_penalty, presence_penalty, logprobs, n, best_of, kwargs:
                Accepted but not forwarded to the API by this backend.

        Returns:
            A list with one dict per prompt containing ``prompt``,
            ``response`` (the rendered ``choices`` dict, or ``None`` if every
            attempt failed) and ``created_at`` (stringified local timestamp).
        """
        if not isinstance(prompts, list):
            prompts = [prompts]

        results = []
        for prompt in prompts:
            response = self._request_with_retries(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stop_sequences=stop_sequences,
                retries=retries,
            )
            results.append(
                {
                    "prompt": prompt,
                    "response": self._render_response(response),
                    "created_at": str(datetime.now()),
                }
            )

        return results
123+
124+
125+
class ClaudeSonnet(ClaudeTextGenerationAPI):
    """Preset of ``ClaudeTextGenerationAPI`` pinned to a Claude 3 Sonnet snapshot."""

    config_name = "claude_3_sonnet"

    def __init__(self, api_key, request_batch_size=1):
        """Bind the client to the fixed Sonnet model identifier.

        Args:
            api_key: Anthropic API key, passed through to the parent class.
            request_batch_size: Passed through to the parent class.
        """
        # NOTE(review): dated model snapshot — confirm the Anthropic API
        # still serves this identifier.
        model_id = "claude-3-sonnet-20240229"
        super().__init__(
            model=model_id,
            api_key=api_key,
            request_batch_size=request_batch_size,
        )

tests/xturing/model_apis/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)