Commit 2cf6f91

Implement working with_response_format for Open AI
1 parent a0f3704 commit 2cf6f91

19 files changed: 1209 additions and 13 deletions

.gitignore

Lines changed: 3 additions & 2 deletions
@@ -47,8 +47,8 @@ build-iPhoneSimulator/
 # for a library or gem, you might want to ignore these files since the code is
 # intended to run in multiple environments; otherwise, check them in:
 Gemfile.lock
-# .ruby-version
-# .ruby-gemset
+.ruby-version
+.ruby-gemset
 
 # unless supporting rvm < 1.11.0 or doing something fancy, ignore this:
 .rvmrc
@@ -57,3 +57,4 @@ Gemfile.lock
 # .rubocop-https?--*
 
 repomix-output.*
+/.idea/

Gemfile

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ group :development do
   gem 'nokogiri'
   gem 'overcommit', '>= 0.66'
   gem 'pry', '>= 0.14'
+  gem 'pry-byebug', '>= 3.11'
   gem 'rake', '>= 13.0'
   gem 'rdoc'
   gem 'reline'

README.md

Lines changed: 3 additions & 0 deletions
@@ -60,6 +60,9 @@ chat.ask "Tell me a story about a Ruby programmer" do |chunk|
   print chunk.content
 end
 
+# Get structured responses easily (OpenAI only for now)
+chat.with_response_format(:integer).ask("What is 2 + 2?").to_i # => 4
+
 # Generate images
 RubyLLM.paint "a sunset over mountains in watercolor style"
 

docs/guides/chat.md

Lines changed: 49 additions & 1 deletion
@@ -261,6 +261,54 @@ end
 chat.ask "What is metaprogramming in Ruby?"
 ```
 
+## Receiving Structured Responses
+You can ensure the responses follow a schema you define like this:
+```ruby
+chat = RubyLLM.chat
+
+chat.with_response_format(:integer).ask("What is 2 + 2?").to_i
+# => 4
+
+chat.with_response_format(:string).ask("Say 'Hello World' and nothing else.").content
+# => "Hello World"
+
+chat.with_response_format(:array, items: { type: :string })
+chat.ask('What are the 2 largest countries? Only respond with country names.').content
+# => ["Russia", "Canada"]
+
+chat.with_response_format(:object, properties: { age: { type: :integer } })
+chat.ask('Provide sample customer age between 10 and 100.').content
+# => { "age" => 42 }
+
+chat.with_response_format(
+  :object,
+  properties: { hobbies: { type: :array, items: { type: :string, enum: %w[Soccer Golf Hockey] } } }
+)
+chat.ask('Provide at least 1 hobby.').content
+# => { "hobbies" => ["Soccer"] }
+```
+
+You can also provide the JSON schema you want directly to the method like this:
+```ruby
+chat.with_response_format(type: :object, properties: { age: { type: :integer } })
+# => { "age" => 31 }
+```
+
+In this example the code is automatically switching to OpenAI's json_mode since no object properties are requested:
+```ruby
+chat.with_response_format(:json) # Don't care about structure, just give me JSON
+
+chat.ask('Provide a sample customer data object with name and email keys.').content
+# => { "name" => "Tobias", "email" => "[email protected]" }
+
+chat.ask('Provide a sample customer data object with name and email keys.').content
+# => { "first_name" => "Michael", "email_address" => "[email protected]" }
+```
+
+{: .note }
+**Only OpenAI supported for now:** Only OpenAI models support this feature for now. We will add support for other models shortly.
+
+
 ## Next Steps
 
 This guide covered the core `Chat` interface. Now you might want to explore:
@@ -269,4 +317,4 @@ This guide covered the core `Chat` interface. Now you might want to explore:
 * [Using Tools]({% link guides/tools.md %}): Enable the AI to call your Ruby code.
 * [Streaming Responses]({% link guides/streaming.md %}): Get real-time feedback from the AI.
 * [Rails Integration]({% link guides/rails.md %}): Persist your chat conversations easily.
-* [Error Handling]({% link guides/error-handling.md %}): Build robust applications that handle API issues.
+* [Error Handling]({% link guides/error-handling.md %}): Build robust applications that handle API issues.

lib/ruby_llm/chat.rb

Lines changed: 59 additions & 8 deletions
@@ -31,6 +31,55 @@ def initialize(model: nil, provider: nil, assume_model_exists: false, context: n
       }
     end
 
+    ##
+    # This method lets you ensure the responses follow a schema you define like this:
+    #
+    # chat.with_response_format(:integer).ask("What is 2 + 2?").to_i
+    # # => 4
+    # chat.with_response_format(:string).ask("Say 'Hello World' and nothing else.").content
+    # # => "Hello World"
+    # chat.with_response_format(:array, items: { type: :string })
+    # chat.ask('What are the 2 largest countries? Only respond with country names.').content
+    # # => ["Russia", "Canada"]
+    # chat.with_response_format(:object, properties: { age: { type: :integer } })
+    # chat.ask('Provide sample customer age between 10 and 100.').content
+    # # => { "age" => 42 }
+    # chat.with_response_format(
+    # :object,
+    # properties: { hobbies: { type: :array, items: { type: :string, enum: %w[Soccer Golf Hockey] } } }
+    # )
+    # chat.ask('Provide at least 1 hobby.').content
+    # # => { "hobbies" => ["Soccer"] }
+    #
+    # You can also provide the JSON schema you want directly to the method like this:
+    # chat.with_response_format(type: :object, properties: { age: { type: :integer } })
+    # # => { "age" => 31 }
+    #
+    # In this example the code is automatically switching to OpenAI's json_mode since no object
+    # properties are requested:
+    # chat.with_response_format(:json) # Don't care about structure, just give me JSON
+    # chat.ask('Provide a sample customer data object with name and email keys.').content
+    # # => { "name" => "Tobias", "email" => "[email protected]" }
+    # chat.ask('Provide a sample customer data object with name and email keys.').content
+    # # => { "first_name" => "Michael", "email_address" => "[email protected]" }
+    #
+    # @param type [Symbol] (optional) This can be anything supported by the API JSON schema types (integer, object, etc)
+    # @param schema [Hash] The schema for the response format. It can be a JSON schema or a simple hash.
+    # @return [Chat] (self)
+    def with_response_format(type = nil, **schema)
+      schema_hash = if type.is_a?(Symbol) || type.is_a?(String)
+                      { type: type == :json ? :object : type }
+                    elsif type.is_a?(Hash)
+                      type
+                    else
+                      {}
+                    end.merge(schema)
+
+      @response_schema = Schema.new(schema_hash)
+
+      self
+    end
+
     def ask(message = nil, with: {}, &block)
       add_message role: :user, content: Content.new(message, with)
       complete(&block)
@@ -96,14 +145,16 @@ def each(&)
 
     def complete(&) # rubocop:disable Metrics/MethodLength
       @on[:new_message]&.call
-      response = @provider.complete(
-        messages,
-        tools: @tools,
-        temperature: @temperature,
-        model: @model.id,
-        connection: @connection,
-        &
-      )
+      response = @provider.with_response_schema(@response_schema) do
+        @provider.complete(
+          messages,
+          tools: @tools,
+          temperature: @temperature,
+          model: @model.id,
+          connection: @connection,
+          &
+        )
+      end
       @on[:end_message]&.call(response)
 
       add_message response
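
The argument handling in `with_response_format` can be easier to follow in isolation. The sketch below restates that branching as a standalone function; the name `normalize_response_format` is hypothetical and not part of the commit, and the `Schema` wrapper is left out on purpose.

```ruby
# Hypothetical standalone helper mirroring the branching in with_response_format.
# It returns a plain Hash instead of wrapping the result in RubyLLM::Schema.
def normalize_response_format(type = nil, **schema)
  base = case type
         when Symbol, String
           # :json is treated as shorthand for "any JSON object".
           { type: type == :json ? :object : type }
         when Hash
           # A full schema passed positionally is used as-is.
           type
         else
           {}
         end
  # Keyword arguments refine or override the base schema.
  base.merge(schema)
end

p normalize_response_format(:integer)
p normalize_response_format(:json)
p normalize_response_format(:array, items: { type: :string })
p normalize_response_format(type: :object, properties: { age: { type: :integer } })
```

Two details follow from this shape. A call that passes only keywords, as in the last line, leaves `type` as `nil`, so the whole schema comes from the keyword hash. The method also returns the chat rather than a response, so the documented `# => { "age" => 31 }` result would only appear after a subsequent `ask(...).content`.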

lib/ruby_llm/message.rb

Lines changed: 22 additions & 1 deletion
@@ -7,7 +7,9 @@ module RubyLLM
   class Message
     ROLES = %i[system user assistant tool].freeze
 
-    attr_reader :role, :content, :tool_calls, :tool_call_id, :input_tokens, :output_tokens, :model_id
+    attr_reader :role, :tool_calls, :tool_call_id, :input_tokens, :output_tokens, :model_id
+
+    delegate :to_i, :to_a, :to_s, to: :content
 
     def initialize(options = {})
       @role = options[:role].to_sym
@@ -17,10 +19,29 @@ def initialize(options = {})
       @output_tokens = options[:output_tokens]
       @model_id = options[:model_id]
       @tool_call_id = options[:tool_call_id]
+      @schema = options[:schema]
 
       ensure_valid_role
     end
 
+    def content
+      return @content unless @schema.present?
+
+      if @schema[:type].to_s == :object.to_s && @schema[:properties].to_h.keys.none?
+        json_response
+      else
+        structured_content
+      end
+    end
+
+    def json_response
+      JSON.parse(@content)
+    end
+
+    def structured_content
+      json_response['result']
+    end
+
     def tool_call?
       !tool_calls.nil? && !tool_calls.empty?
     end
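
The new `content` reader decodes the raw reply in one of two ways, depending on the schema that was requested. A minimal sketch of that decision, assuming plain hashes in place of `RubyLLM::Schema` and a hypothetical helper name `decode_content`:

```ruby
require 'json'

# Hypothetical helper: decode a raw model reply given the requested schema.
# Plain hashes stand in for RubyLLM::Schema here.
def decode_content(raw, schema)
  # No schema requested: hand the text back untouched.
  return raw if schema.nil? || schema[:type].nil?

  parsed = JSON.parse(raw)
  # An object schema without properties means free-form JSON mode,
  # where the reply itself is the payload.
  json_mode = schema[:type].to_s == 'object' && schema[:properties].to_h.empty?
  # Otherwise the payload sits under the "result" key of the wrapper object.
  json_mode ? parsed : parsed['result']
end

p decode_content('{"result": 4}', { type: :integer })
p decode_content('{"name": "Ada"}', { type: :object })
p decode_content('plain text', nil)
```

Because `JSON.parse` is called without `symbolize_names`, decoded objects carry string keys, which matches the `{ "age" => 42 }` shape shown in the guide.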

lib/ruby_llm/provider.rb

Lines changed: 23 additions & 0 deletions
@@ -31,6 +31,29 @@ def list_models(connection:)
         parse_list_models_response response, slug, capabilities
       end
 
+      ##
+      # @return [::RubyLLM::Schema, NilClass]
+      def response_schema
+        Thread.current[:response_schema]
+      end
+
+      ##
+      # @param response_schema [::RubyLLM::Schema]
+      def with_response_schema(response_schema)
+        prev_response_schema = Thread.current[:response_schema]
+
+        result = nil
+        begin
+          Thread.current[:response_schema] = response_schema
+
+          result = yield
+        ensure
+          Thread.current[:response_schema] = prev_response_schema
+        end
+
+        result
+      end
+
       def embed(text, model:, connection:)
         payload = render_embedding_payload(text, model:)
         response = connection.post embedding_url, payload
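
`with_response_schema` follows a save, set, restore pattern so that the schema is visible only while the block runs. The sketch below shows the same pattern in isolation; the module name `ScopedSchema` and the key `:example_response_schema` are made up for illustration and are not the library's names.

```ruby
# Hypothetical module illustrating the save / set / restore pattern.
module ScopedSchema
  KEY = :example_response_schema # assumed key name, not the library's

  def self.current
    Thread.current[KEY]
  end

  # Makes +schema+ visible to the block, then restores the previous value,
  # even if the block raises.
  def self.with(schema)
    previous = Thread.current[KEY]
    Thread.current[KEY] = schema
    yield
  ensure
    Thread.current[KEY] = previous
  end
end

ScopedSchema.with(:outer) do
  ScopedSchema.with(:inner) { p ScopedSchema.current }
  p ScopedSchema.current
end
p ScopedSchema.current
```

One point worth keeping in mind when reading the diff: `Thread.current[]` stores fiber-local values in Ruby, so code that runs in a different thread or fiber inside the block would not see the schema set here.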

lib/ruby_llm/providers/openai/chat.rb

Lines changed: 52 additions & 0 deletions
@@ -22,6 +22,9 @@ def render_payload(messages, tools:, temperature:, model:, stream: false) # rubo
               payload[:tools] = tools.map { |_, tool| tool_for(tool) }
               payload[:tool_choice] = 'auto'
             end
+
+            add_response_schema_to_payload(payload) if response_schema.present?
+
             payload[:stream_options] = { include_usage: true } if stream
           end
         end
@@ -35,6 +38,7 @@ def parse_completion_response(response) # rubocop:disable Metrics/MethodLength
 
           Message.new(
             role: :assistant,
+            schema: response_schema,
             content: message_data['content'],
             tool_calls: parse_tool_calls(message_data['tool_calls']),
             input_tokens: data['usage']['prompt_tokens'],
@@ -62,6 +66,54 @@ def format_role(role)
             role.to_s
           end
         end
+
+        private
+
+        ##
+        # @param [Hash] payload
+        def add_response_schema_to_payload(payload)
+          payload[:response_format] = gen_response_format_request
+
+          return unless payload[:response_format][:type] == :json_object
+
+          # NOTE: this is required by the Open AI API when requesting arbitrary JSON.
+          payload[:messages].unshift({ role: :developer, content: <<~GUIDANCE
+            You must format your output as a valid JSON object.
+            Format your entire response as valid JSON.
+            Do not include explanations, markdown formatting, or any text outside the JSON.
+          GUIDANCE
+          })
+        end
+
+        ##
+        # @return [Hash]
+        def gen_response_format_request
+          if response_schema[:type].to_s == :object.to_s && response_schema[:properties].to_h.keys.none?
+            { type: :json_object } # Assume we just want json_mode
+          else
+            gen_json_schema_format_request
+          end
+        end
+
+        def gen_json_schema_format_request # rubocop:disable Metrics/MethodLength -- because it's mostly the standard hash
+          result_schema = response_schema.dup # so we don't modify the original in the thread
+          result_schema.add_to_each_object_type!(:additionalProperties, false)
+          result_schema.add_to_each_object_type!(:required, ->(schema) { schema[:properties].to_h.keys })
+
+          {
+            type: :json_schema,
+            json_schema: {
+              name: :response,
+              schema: {
+                type: :object,
+                properties: { result: result_schema.to_h },
+                additionalProperties: false,
+                required: [:result]
+              },
+              strict: true
+            }
+          }
+        end
       end
     end
   end
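
The payload built here can be tried out without calling the API. The following is a rough standalone sketch of the same idea, assuming plain hashes instead of `RubyLLM::Schema`; `strictify` and `response_format_for` are hypothetical names. It closes every object schema with `additionalProperties: false`, marks all of its properties as required, and wraps the result under a `result` key so that the top level is always an object.

```ruby
# Hypothetical helper: return a copy of +schema+ in which every object type
# is closed (additionalProperties: false) and lists all properties as required.
def strictify(schema)
  return schema unless schema.is_a?(Hash)

  result = schema.dup
  case result[:type].to_s
  when 'object'
    props = (result[:properties] || {}).transform_values { |v| strictify(v) }
    result[:properties] = props if result.key?(:properties)
    result[:additionalProperties] = false unless result.key?(:additionalProperties)
    result[:required] = props.keys unless result.key?(:required)
  when 'array'
    result[:items] = strictify(result[:items]) if result[:items]
  end
  result
end

# Hypothetical helper: choose between free-form JSON mode and a strict JSON schema.
def response_format_for(schema)
  # An object type with no properties means "any JSON object".
  if schema[:type].to_s == 'object' && schema[:properties].to_h.empty?
    return { type: :json_object }
  end

  {
    type: :json_schema,
    json_schema: {
      name: :response,
      schema: {
        type: :object,
        properties: { result: strictify(schema) },
        additionalProperties: false,
        required: [:result]
      },
      strict: true
    }
  }
end

p response_format_for({ type: :object })
p response_format_for({ type: :object, properties: { age: { type: :integer } } })
```

This sketch builds copies instead of mutating its input. In the diff, `response_schema.dup` copies only the `Schema` wrapper, so the in-place `add_to_each_object_type!` calls appear to modify the hash shared with the original schema; that may be worth checking.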

lib/ruby_llm/schema.rb

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  ##
+  # Schema class for defining the structure of data objects.
+  # Wraps the #Hash class
+  # @see #Hash
+  class Schema
+    delegate_missing_to :@schema
+
+    def initialize(schema = {})
+      @schema = deep_transform_keys_in_object(schema.to_h.dup, &:to_sym)
+    end
+
+    def [](key)
+      @schema[key.to_sym]
+    end
+
+    def []=(key, new_value)
+      @schema[key.to_sym] = deep_transform_keys_in_object(new_value, &:to_sym)
+    end
+
+    def add_to_each_object_type!(new_key, new_value)
+      add_to_each_object_type(new_key, new_value, @schema)
+    end
+
+    def present?
+      @schema.present? && @schema[:type].present?
+    end
+
+    private
+
+    def add_to_each_object_type(new_key, new_value, schema)
+      return schema unless schema.is_a?(Hash)
+
+      if schema[:type].to_s == :object.to_s
+        add_to_object_type(new_key, new_value, schema)
+      elsif schema[:type].to_s == :array.to_s && schema[:items]
+        schema[:items] = add_to_each_object_type(new_key, new_value, schema[:items])
+      end
+
+      schema
+    end
+
+    def add_to_object_type(new_key, new_value, schema)
+      if schema[new_key.to_sym].nil?
+        schema[new_key.to_sym] = new_value.is_a?(Proc) ? new_value.call(schema) : new_value
+      end
+
+      schema[:properties]&.transform_values! { |value| add_to_each_object_type(new_key, new_value, value) }
+    end
+
+    ##
+    # Recursively transforms keys in a hash or array to symbols.
+    # Borrowed from ActiveSupport's Hash#deep_transform_keys
+    # @param object [Object] The object to transform.
+    # @param block [Proc] The block to apply to each key.
+    # @return [Object] The transformed object.
+    def deep_transform_keys_in_object(object, &block)
+      case object
+      when Hash
+        object.each_with_object({}) do |(key, value), result|
+          result[yield(key)] = deep_transform_keys_in_object(value, &block)
+        end
+      when Array
+        object.map { |e| deep_transform_keys_in_object(e, &block) }
+      else
+        object
+      end
+    end
+  end
+end
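
`Schema` normalizes keys to symbols on the way in, which is what allows a JSON schema with string keys to be passed to `with_response_format`. Here is a small sketch of that recursion as a free function; `deep_symbolize` is a hypothetical name used only for this illustration.

```ruby
# Hypothetical free function showing the recursive key normalization.
# Only keys are converted; values such as "object" stay strings.
def deep_symbolize(object)
  case object
  when Hash
    object.each_with_object({}) { |(key, value), out| out[key.to_sym] = deep_symbolize(value) }
  when Array
    object.map { |element| deep_symbolize(element) }
  else
    object
  end
end

raw = {
  'type' => 'object',
  'properties' => { 'tags' => { 'type' => 'array', 'items' => { 'type' => 'string' } } }
}
p deep_symbolize(raw)
```

Since values are left alone, a type may arrive as either `:object` or `"object"`. That is presumably why the diff compares types with `schema[:type].to_s == :object.to_s` rather than comparing symbols directly.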
