Skip to content

Commit 0f6705b

Browse files
committed
Improve extensibility of DefaultChatOptionsBuilder
- Enable DefaultChatOptionsBuilder to accommodate any other sub types - Introduce generics to support sub types that extend DefaultChatOptionsBuilder - Update builder methods to return the sub type - Make FunctionCallingOptions' builder()'s return type to accommodate sub types which can extend FunctionCallingOptions.Builder
1 parent bdc2778 commit 0f6705b

File tree

5 files changed

+42
-79
lines changed

5 files changed

+42
-79
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ public interface ChatOptions extends ModelOptions {
9494
* {@link ChatOptions}.
9595
* @return Returns a new {@link ChatOptions.Builder}.
9696
*/
97-
static ChatOptions.Builder builder() {
97+
static ChatOptions.Builder<? extends DefaultChatOptionsBuilder> builder() {
9898
return new DefaultChatOptionsBuilder();
9999
}
100100

spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,48 +21,52 @@
2121
/**
2222
* Implementation of {@link ChatOptions.Builder} to create {@link DefaultChatOptions}.
2323
*/
24-
public class DefaultChatOptionsBuilder implements ChatOptions.Builder<DefaultChatOptionsBuilder> {
24+
public class DefaultChatOptionsBuilder<T extends DefaultChatOptionsBuilder<T>> implements ChatOptions.Builder<T> {
2525

2626
private final DefaultChatOptions options = new DefaultChatOptions();
2727

28-
public DefaultChatOptionsBuilder model(String model) {
28+
protected T self() {
29+
return (T) this;
30+
}
31+
32+
public T model(String model) {
2933
this.options.setModel(model);
30-
return this;
34+
return self();
3135
}
3236

33-
public DefaultChatOptionsBuilder frequencyPenalty(Double frequencyPenalty) {
37+
public T frequencyPenalty(Double frequencyPenalty) {
3438
this.options.setFrequencyPenalty(frequencyPenalty);
35-
return this;
39+
return self();
3640
}
3741

38-
public DefaultChatOptionsBuilder maxTokens(Integer maxTokens) {
42+
public T maxTokens(Integer maxTokens) {
3943
this.options.setMaxTokens(maxTokens);
40-
return this;
44+
return self();
4145
}
4246

43-
public DefaultChatOptionsBuilder presencePenalty(Double presencePenalty) {
47+
public T presencePenalty(Double presencePenalty) {
4448
this.options.setPresencePenalty(presencePenalty);
45-
return this;
49+
return self();
4650
}
4751

48-
public DefaultChatOptionsBuilder stopSequences(List<String> stop) {
52+
public T stopSequences(List<String> stop) {
4953
this.options.setStopSequences(stop);
50-
return this;
54+
return self();
5155
}
5256

53-
public DefaultChatOptionsBuilder temperature(Double temperature) {
57+
public T temperature(Double temperature) {
5458
this.options.setTemperature(temperature);
55-
return this;
59+
return self();
5660
}
5761

58-
public DefaultChatOptionsBuilder topK(Integer topK) {
62+
public T topK(Integer topK) {
5963
this.options.setTopK(topK);
60-
return this;
64+
return self();
6165
}
6266

63-
public DefaultChatOptionsBuilder topP(Double topP) {
67+
public T topP(Double topP) {
6468
this.options.setTopP(topP);
65-
return this;
69+
return self();
6670
}
6771

6872
public ChatOptions build() {

spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptions.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,6 @@ public class DefaultFunctionCallingOptions extends DefaultChatOptions implements
4747

4848
private Map<String, Object> context = new HashMap<>();
4949

50-
public static FunctionCallingOptions.Builder builder() {
51-
return new DefaultFunctionCallingOptionsBuilder();
52-
}
53-
5450
@Override
5551
public List<FunctionCallback> getFunctionCallbacks() {
5652
return Collections.unmodifiableList(this.functionCallbacks);

spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilder.java

Lines changed: 11 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.Map;
2323
import java.util.Set;
2424

25+
import org.springframework.ai.chat.prompt.DefaultChatOptionsBuilder;
2526
import org.springframework.util.Assert;
2627

2728
/**
@@ -31,88 +32,50 @@
3132
* @author Thomas Vitale
3233
* @author Ilayaperumal Gopinathan
3334
*/
34-
public class DefaultFunctionCallingOptionsBuilder implements FunctionCallingOptions.Builder {
35+
public class DefaultFunctionCallingOptionsBuilder
36+
extends DefaultChatOptionsBuilder<DefaultFunctionCallingOptionsBuilder>
37+
implements FunctionCallingOptions.Builder<DefaultFunctionCallingOptionsBuilder> {
3538

3639
private final DefaultFunctionCallingOptions functionCallingOptions = new DefaultFunctionCallingOptions();
3740

38-
public FunctionCallingOptions.Builder model(String model) {
39-
this.functionCallingOptions.setModel(model);
40-
return this;
41-
}
42-
43-
public FunctionCallingOptions.Builder frequencyPenalty(Double frequencyPenalty) {
44-
this.functionCallingOptions.setFrequencyPenalty(frequencyPenalty);
45-
return this;
46-
}
47-
48-
public FunctionCallingOptions.Builder maxTokens(Integer maxTokens) {
49-
this.functionCallingOptions.setMaxTokens(maxTokens);
50-
return this;
51-
}
52-
53-
public FunctionCallingOptions.Builder presencePenalty(Double presencePenalty) {
54-
this.functionCallingOptions.setPresencePenalty(presencePenalty);
55-
return this;
56-
}
57-
58-
public FunctionCallingOptions.Builder stopSequences(List<String> stopSequences) {
59-
this.functionCallingOptions.setStopSequences(stopSequences);
60-
return this;
61-
}
62-
63-
public FunctionCallingOptions.Builder temperature(Double temperature) {
64-
this.functionCallingOptions.setTemperature(temperature);
65-
return this;
66-
}
67-
68-
public FunctionCallingOptions.Builder topK(Integer topK) {
69-
this.functionCallingOptions.setTopK(topK);
70-
return this;
71-
}
72-
73-
public FunctionCallingOptions.Builder topP(Double topP) {
74-
this.functionCallingOptions.setTopP(topP);
75-
return this;
76-
}
77-
78-
public FunctionCallingOptions.Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
41+
public DefaultFunctionCallingOptionsBuilder functionCallbacks(List<FunctionCallback> functionCallbacks) {
7942
this.functionCallingOptions.setFunctionCallbacks(functionCallbacks);
8043
return this;
8144
}
8245

83-
public FunctionCallingOptions.Builder functionCallbacks(FunctionCallback... functionCallbacks) {
46+
public DefaultFunctionCallingOptionsBuilder functionCallbacks(FunctionCallback... functionCallbacks) {
8447
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
8548
this.functionCallingOptions.setFunctionCallbacks(List.of(functionCallbacks));
8649
return this;
8750
}
8851

89-
public FunctionCallingOptions.Builder functions(Set<String> functions) {
52+
public DefaultFunctionCallingOptionsBuilder functions(Set<String> functions) {
9053
this.functionCallingOptions.setFunctions(functions);
9154
return this;
9255
}
9356

94-
public FunctionCallingOptions.Builder function(String function) {
57+
public DefaultFunctionCallingOptionsBuilder function(String function) {
9558
Assert.notNull(function, "Function must not be null");
9659
var set = new HashSet<>(this.functionCallingOptions.getFunctions());
9760
set.add(function);
9861
this.functionCallingOptions.setFunctions(set);
9962
return this;
10063
}
10164

102-
public FunctionCallingOptions.Builder proxyToolCalls(Boolean proxyToolCalls) {
65+
public DefaultFunctionCallingOptionsBuilder proxyToolCalls(Boolean proxyToolCalls) {
10366
this.functionCallingOptions.setProxyToolCalls(proxyToolCalls);
10467
return this;
10568
}
10669

107-
public FunctionCallingOptions.Builder toolContext(Map<String, Object> context) {
70+
public DefaultFunctionCallingOptionsBuilder toolContext(Map<String, Object> context) {
10871
Assert.notNull(context, "Tool context must not be null");
10972
Map<String, Object> newContext = new HashMap<>(this.functionCallingOptions.getToolContext());
11073
newContext.putAll(context);
11174
this.functionCallingOptions.setToolContext(newContext);
11275
return this;
11376
}
11477

115-
public FunctionCallingOptions.Builder toolContext(String key, Object value) {
78+
public DefaultFunctionCallingOptionsBuilder toolContext(String key, Object value) {
11679
Assert.notNull(key, "Key must not be null");
11780
Assert.notNull(value, "Value must not be null");
11881
Map<String, Object> newContext = new HashMap<>(this.functionCallingOptions.getToolContext());

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public interface FunctionCallingOptions extends ChatOptions {
3535
* @return Returns {@link DefaultFunctionCallingOptionsBuilder} to create a new
3636
* instance of {@link FunctionCallingOptions}.
3737
*/
38-
static FunctionCallingOptions.Builder builder() {
38+
static FunctionCallingOptions.Builder<? extends FunctionCallingOptions.Builder> builder() {
3939
return new DefaultFunctionCallingOptionsBuilder();
4040
}
4141

@@ -87,57 +87,57 @@ default void setProxyToolCalls(Boolean proxyToolCalls) {
8787
/**
8888
* Builder for creating {@link FunctionCallingOptions} instance.
8989
*/
90-
interface Builder extends ChatOptions.Builder<Builder> {
90+
interface Builder<T extends Builder<T>> extends ChatOptions.Builder<T> {
9191

9292
/**
9393
* The list of Function Callbacks to be registered with the Chat model.
9494
* @param functionCallbacks the list of Function Callbacks.
9595
* @return the FunctionCallOptions Builder.
9696
*/
97-
Builder functionCallbacks(List<FunctionCallback> functionCallbacks);
97+
T functionCallbacks(List<FunctionCallback> functionCallbacks);
9898

9999
/**
100100
* The Function Callbacks to be registered with the Chat model.
101101
* @param functionCallbacks the function callbacks.
102102
* @return the FunctionCallOptions Builder.
103103
*/
104-
Builder functionCallbacks(FunctionCallback... functionCallbacks);
104+
T functionCallbacks(FunctionCallback... functionCallbacks);
105105

106106
/**
107107
* {@link Set} of function names to be registered with the Chat model.
108108
* @param functions the {@link Set} of function names
109109
* @return the FunctionCallOptions Builder.
110110
*/
111-
Builder functions(Set<String> functions);
111+
T functions(Set<String> functions);
112112

113113
/**
114114
* The function name to be registered with the chat model.
115115
* @param function the name of the function.
116116
* @return the FunctionCallOptions Builder.
117117
*/
118-
Builder function(String function);
118+
T function(String function);
119119

120120
/**
121121
* Boolean flag to indicate if the proxy ToolCalls is enabled.
122122
* @param proxyToolCalls boolean value to enable proxy ToolCalls.
123123
* @return the FunctionCallOptions Builder.
124124
*/
125-
Builder proxyToolCalls(Boolean proxyToolCalls);
125+
T proxyToolCalls(Boolean proxyToolCalls);
126126

127127
/**
128128
* Add a {@link Map} of context values into tool context.
129129
* @param context the map representing the tool context.
130130
* @return the FunctionCallOptions Builder.
131131
*/
132-
Builder toolContext(Map<String, Object> context);
132+
T toolContext(Map<String, Object> context);
133133

134134
/**
135135
* Add a specific key/value pair to the tool context.
136136
* @param key the key to use.
137137
* @param value the corresponding value.
138138
* @return the FunctionCallOptions Builder.
139139
*/
140-
Builder toolContext(String key, Object value);
140+
T toolContext(String key, Object value);
141141

142142
/**
143143
* Builds the {@link FunctionCallingOptions}.

0 commit comments

Comments
 (0)