Skip to content

Add Schema builder for Structured Outputs #90

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 205 additions & 0 deletions lib/ruby_llm/structured_output.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
# frozen_string_literal: true

module RubyLLM
module StructuredOutput
class Schema
MAX_OBJECT_PROPERTIES = 100
MAX_NESTING_DEPTH = 5

class << self
def string(name = nil, enum: nil, description: nil)
schema = { type: 'string', enum: enum, description: description }.compact
name ? add_property(name, schema) : schema
end

def number(name = nil, description: nil)
schema = { type: 'number', description: description }.compact
name ? add_property(name, schema) : schema
end

def boolean(name = nil, description: nil)
schema = { type: 'boolean', description: description }.compact
name ? add_property(name, schema) : schema
end

def null(name = nil, description: nil)
schema = { type: 'null', description: description }.compact
name ? add_property(name, schema) : schema
end

def object(name = nil, description: nil, &block)
sub_schema = Class.new(Schema)
sub_schema.class_eval(&block)

schema = {
type: 'object',
properties: sub_schema.properties,
required: sub_schema.required,
additionalProperties: false,
description: description
}.compact

name ? add_property(name, schema) : schema
end

def array(name, type = nil, description: nil, &block)
items = if block_given?
collector = SchemaCollector.new
collector.instance_eval(&block)
collector.schemas.first
elsif type.is_a?(Symbol)
case type
when :string, :number, :boolean, :null
send(type)
else
ref(type)
end
else
raise ArgumentError, "Invalid array type: #{type}"
end

add_property(name, {
type: 'array',
description: description,
items: items
}.compact)
end

def any_of(name, description: nil, &block)
collector = SchemaCollector.new
collector.instance_eval(&block)

add_property(name, {
description: description,
anyOf: collector.schemas
}.compact)
end

def ref(schema_name)
{ '$ref' => "#/$defs/#{schema_name}" }
end

def properties
@properties ||= {}
end

def required
@required ||= []
end

def definitions
@definitions ||= {}
end

def define(name, &)
sub_schema = Class.new(Schema)
sub_schema.class_eval(&)

definitions[name] = {
type: 'object',
properties: sub_schema.properties,
required: sub_schema.required
}
end

private

def add_property(name, definition)
properties[name.to_sym] = definition
required << name.to_sym
end
end

# Simple collector that just stores schemas
class SchemaCollector
attr_reader :schemas

def initialize
@schemas = []
end

def method_missing(method_name, ...)
if Schema.respond_to?(method_name)
@schemas << Schema.send(method_name, ...)
else
super
end
end

def respond_to_missing?(method_name, include_private = false)
Schema.respond_to?(method_name) || super
end
end

def initialize(name = nil)
@name = name || self.class.name
validate_schema
end

def json_schema
{
name: @name,
description: 'Schema for the structured response',
schema: {
type: 'object',
properties: self.class.properties,
required: self.class.required,
additionalProperties: false,
strict: true,
'$defs' => self.class.definitions
}
}
end

private

# Validate the schema against defined limits
def validate_schema
properties_count = count_properties(self.class.properties)
raise 'Exceeded maximum number of object properties' if properties_count > MAX_OBJECT_PROPERTIES

max_depth = calculate_max_depth(self.class.properties)
raise 'Exceeded maximum nesting depth' if max_depth > MAX_NESTING_DEPTH
end

# Count the total number of properties in the schema
def count_properties(schema)
return 0 unless schema.is_a?(Hash) && schema[:properties]

count = schema[:properties].size
schema[:properties].each_value do |prop|
count += count_properties(prop)
end
count
end

# Calculate the maximum nesting depth of the schema
def calculate_max_depth(schema, current_depth = 1)
return current_depth unless schema.is_a?(Hash)

if schema[:type] == 'object' && schema[:properties]
child_depths = schema[:properties].values.map do |prop|
calculate_max_depth(prop, current_depth + 1)
end
[current_depth, child_depths.max].compact.max
elsif schema[:items] # For arrays
calculate_max_depth(schema[:items], current_depth + 1)
else
current_depth
end
end

def method_missing(method_name, ...)
if respond_to_missing?(method_name)
send(method_name, ...)
else
super
end
end

def respond_to_missing?(method_name, include_private = false)
%i[string number boolean array object any_of null].include?(method_name) || super
end
end
end
end
129 changes: 129 additions & 0 deletions spec/ruby_llm/structured_output_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# frozen_string_literal: true

require 'spec_helper'

RSpec.describe RubyLLM::StructuredOutput::Schema do # rubocop:disable RSpec/SpecFilePathFormat
describe 'schema definition' do
json_output { schema.json_schema }

let(:schema_class) do
Class.new(described_class) do
string :name, description: "User's name"
number :age
boolean :active

object :address do
string :street
string :city
end

array :tags, :string, description: 'User tags'

array :contacts do
object do
string :email
string :phone
end
end

any_of :status do
string enum: %w[active pending]
null
end

define :location do
string :latitude
string :longitude
end

array :locations, :location
end
end

let(:schema) { schema_class.new }

it 'generates the correct JSON schema' do # rubocop:disable RSpec/ExampleLength,RSpec/MultipleExpectations
expect(json_output).to include(
name: schema_class.name,
description: 'Schema for the structured response'
)

properties = json_output[:schema][:properties]

# Test basic types
expect(properties[:name]).to eq({ type: 'string', description: "User's name" })
expect(properties[:age]).to eq({ type: 'number' })
expect(properties[:active]).to eq({ type: 'boolean' })

# Test nested object
expect(properties[:address]).to include(
type: 'object',
properties: {
street: { type: 'string' },
city: { type: 'string' }
},
required: %i[street city],
additionalProperties: false
)

# Test arrays
expect(properties[:tags]).to eq({
type: 'array',
description: 'User tags',
items: { type: 'string' }
})

expect(properties[:contacts]).to include(
type: 'array',
items: {
type: 'object',
properties: {
email: { type: 'string' },
phone: { type: 'string' }
},
required: %i[email phone],
additionalProperties: false
}
)

# Test any_of
expect(properties[:status]).to include(
anyOf: [
{ type: 'string', enum: %w[active pending] },
{ type: 'null' }
]
)

# Test references
expect(properties[:locations]).to eq({
type: 'array',
items: { '$ref' => '#/$defs/location' }
})

# Test definitions
expect(json_output[:schema]['$defs']).to include(
location: {
type: 'object',
properties: {
latitude: { type: 'string' },
longitude: { type: 'string' }
},
required: %i[latitude longitude]
}
)
end

it 'includes all properties in required array' do
expect(json_output[:schema][:required]).to contain_exactly(
:name, :age, :active, :address, :tags, :contacts, :status, :locations
)
end

it 'enforces schema constraints' do
expect(json_output[:schema]).to include(
additionalProperties: false,
strict: true
)
end
end
end