Skip to content

Commit 0108208

Browse files
committed
feat(pr): include anthropic in model selection for PR command
1 parent a6c5874 commit 0108208

File tree

2 files changed

+96
-5
lines changed

2 files changed

+96
-5
lines changed

cmd/pr.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,20 @@ var prCmd = &cobra.Command{
5050
var aiProvider PrProvider
5151

5252
providerName := config.GetProvider()
53-
apiKey, err := config.GetAPIKey()
54-
if err != nil {
55-
fmt.Fprintf(os.Stderr, "Error getting API key: %v\n", err)
56-
os.Exit(1)
53+
54+
// API key is not needed for anthropic provider (uses CLI)
55+
var apiKey string
56+
if providerName != "anthropic" {
57+
var err error
58+
apiKey, err = config.GetAPIKey()
59+
if err != nil {
60+
fmt.Fprintf(os.Stderr, "Error getting API key: %v\n", err)
61+
os.Exit(1)
62+
}
5763
}
5864

5965
var model string
60-
if providerName == "copilot" || providerName == "openai" {
66+
if providerName == "copilot" || providerName == "openai" || providerName == "anthropic" {
6167
var err error
6268
model, err = config.GetModel()
6369
if err != nil {
@@ -77,6 +83,10 @@ var prCmd = &cobra.Command{
7783
aiProvider = provider.NewCopilotProviderWithModel(apiKey, model, endpoint)
7884
case "openai":
7985
aiProvider = provider.NewOpenAIProvider(apiKey, model, endpoint)
86+
case "anthropic":
87+
// Get num_suggestions from config
88+
numSuggestions := config.GetNumSuggestions()
89+
aiProvider = provider.NewAnthropicProvider(model, numSuggestions)
8090
default:
8191
// Default to copilot if provider is not set or unknown
8292
aiProvider = provider.NewCopilotProvider(apiKey, endpoint)

internal/provider/anthropic.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,84 @@ func (a *AnthropicProvider) GenerateCommitMessages(ctx context.Context, diff str
109109

110110
return commitMessages, nil
111111
}
112+
113+
func (a *AnthropicProvider) GeneratePRTitle(ctx context.Context, diff string) (string, error) {
114+
titles, err := a.GeneratePRTitles(ctx, diff)
115+
if err != nil {
116+
return "", err
117+
}
118+
if len(titles) == 0 {
119+
return "", fmt.Errorf("no PR titles generated")
120+
}
121+
return titles[0], nil
122+
}
123+
124+
func (a *AnthropicProvider) GeneratePRTitles(ctx context.Context, diff string) ([]string, error) {
125+
if strings.TrimSpace(diff) == "" {
126+
return nil, fmt.Errorf("no diff provided")
127+
}
128+
129+
// Check if claude CLI is available
130+
if _, err := exec.LookPath("claude"); err != nil {
131+
return nil, fmt.Errorf("claude CLI not found in PATH. Please install Claude Code CLI: %w", err)
132+
}
133+
134+
// Build the prompt using PR title template
135+
systemMsg := GetSystemMessage()
136+
userPrompt := GetPRTitlePrompt(diff)
137+
138+
// Modify the prompt to request specific number of suggestions
139+
fullPrompt := fmt.Sprintf("%s\n\nUser request: %s\n\nIMPORTANT: Generate exactly %d pull request titles, one per line. Do not include any other text, explanations, or formatting - just the PR titles.",
140+
systemMsg, userPrompt, a.numSuggestions)
141+
142+
// Execute claude CLI with the specified model
143+
cmd := exec.CommandContext(ctx, "claude", "--model", a.model, "-p", fullPrompt)
144+
145+
output, err := cmd.CombinedOutput()
146+
if err != nil {
147+
return nil, fmt.Errorf("error executing claude CLI: %w\nOutput: %s", err, string(output))
148+
}
149+
150+
// Parse the output - same logic as commit message generation
151+
content := string(output)
152+
lines := strings.Split(content, "\n")
153+
154+
var prTitles []string
155+
for _, line := range lines {
156+
trimmed := strings.TrimSpace(line)
157+
if trimmed == "" {
158+
continue
159+
}
160+
if len(trimmed) > 200 {
161+
continue
162+
}
163+
// Skip markdown formatting or numbered lists
164+
if strings.HasPrefix(trimmed, "#") || strings.HasPrefix(trimmed, "-") || strings.HasPrefix(trimmed, "*") {
165+
parts := strings.SplitN(trimmed, " ", 2)
166+
if len(parts) == 2 {
167+
trimmed = strings.TrimSpace(parts[1])
168+
}
169+
}
170+
// Remove numbered list formatting like "1. " or "1) "
171+
if len(trimmed) > 3 {
172+
if (trimmed[0] >= '0' && trimmed[0] <= '9') && (trimmed[1] == '.' || trimmed[1] == ')') {
173+
trimmed = strings.TrimSpace(trimmed[2:])
174+
}
175+
}
176+
177+
if trimmed != "" {
178+
prTitles = append(prTitles, trimmed)
179+
}
180+
181+
// Stop once we have enough titles
182+
if len(prTitles) >= a.numSuggestions {
183+
break
184+
}
185+
}
186+
187+
if len(prTitles) == 0 {
188+
return nil, fmt.Errorf("no valid PR titles generated from Claude output")
189+
}
190+
191+
return prTitles, nil
192+
}

0 commit comments

Comments
 (0)