Skip to content

Commit

Permalink
Add upscale method to stability ai
Browse files Browse the repository at this point in the history
  • Loading branch information
alchaplinsky committed Oct 26, 2023
1 parent ba8d4b8 commit c998e4d
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 1 deletion.
4 changes: 3 additions & 1 deletion lib/gen_ai/image.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ def edit(image, prompt, options = {})
client.edit(image, prompt, options)
end

# def upscale; end
# Upscales an image by delegating to the configured provider client.
#
# @param image [Object] image reference, forwarded to the client unchanged
# @param options [Hash] provider-specific options, forwarded unchanged
# @return [Object] whatever the provider client's +upscale+ returns
def upscale(image, options = {})
  provider_client = client
  provider_client.upscale(image, options)
end

private

Expand Down
8 changes: 8 additions & 0 deletions lib/gen_ai/image/base.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ def generate(...)
# Fallback used when a provider does not define image variations.
# Accepts any arguments and ignores them.
#
# @raise [NotImplementedError] always, with a message naming the receiver's class
def variations(...)
  unsupported = "#{self.class.name} does not support variations"
  raise NotImplementedError, unsupported
end

# Fallback used when a provider does not define image editing.
# Accepts any arguments and ignores them.
#
# @raise [NotImplementedError] always, with a message naming the receiver's class
def edit(...)
  unsupported = "#{self.class.name} does not support editing"
  raise NotImplementedError, unsupported
end

# Fallback used when a provider does not define image upscaling.
# Accepts any arguments and ignores them.
#
# @raise [NotImplementedError] always, with a message naming the receiver's class
def upscale(...)
  unsupported = "#{self.class.name} does not support upscaling"
  raise NotImplementedError, unsupported
end
end
end
end
22 changes: 22 additions & 0 deletions lib/gen_ai/image/stability_ai.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class StabilityAI < Base
# "WIDTHxHEIGHT" string that #size falls back to when no :size option is given.
DEFAULT_SIZE = '256x256'
# Base URL of the Stability AI API.
# NOTE(review): presumably consumed by build_client; its body is not visible here — confirm.
API_BASE_URL = 'https://api.stability.ai'
# Default model name.
# NOTE(review): presumably the fallback for generation requests; usage is outside this view — confirm.
DEFAULT_MODEL = 'stable-diffusion-xl-beta-v2-2-2'
# Model that #upscale uses when the options carry no :model.
UPSCALE_MODEL = 'stable-diffusion-x4-latent-upscaler'

def initialize(token:, options: {})
build_client(token)
Expand Down Expand Up @@ -39,6 +40,19 @@ def edit(image, prompt, options = {})
)
end

# Upscales an image through the Stability AI image-to-image upscale endpoint.
#
# @param image [String] path of the image file to upload
# @param options [Hash] request options; +:model+ selects the engine
#   (defaults to UPSCALE_MODEL), +:size+ ("WIDTHxHEIGHT") supplies the target
#   width, and every remaining key is merged into the request body as-is.
#   The hash passed in is never modified.
# @return [Object] the value produced by +build_result+
def upscale(image, options = {})
  # Work on a copy: building the body consumes :size (and :model below), and
  # the caller's hash must stay reusable across calls.
  request_options = options.dup
  # The model is addressed through the URL, so keep it out of the form body.
  model = request_options.delete(:model) || UPSCALE_MODEL
  url = "/v1/generation/#{model}/image-to-image/upscale"

  response = client.post url, build_upscale_body(image, request_options), multipart: true

  build_result(
    raw: response,
    model: model,
    parsed: response['artifacts'].map { |artifact| artifact['base64'] }
  )
end

private

def build_client(token)
Expand All @@ -65,6 +79,14 @@ def build_edit_body(image, prompt, options)
params.merge(options)
end

# Builds the request body for an upscale call.
#
# The target width comes from #size, which removes :size from +options+;
# whatever is left in +options+ is merged last, so it overrides the
# +:image+ / +:width+ entries on key collision.
#
# @param image [String] path of the image file, read in binary mode
# @param options [Hash] request options (its :size key is consumed)
# @return [Hash] body containing +:image+, +:width+ and the remaining options
def build_upscale_body(image, options)
  target_width = size(options).first
  body = { image: File.binread(image), width: target_width }
  body.merge(options)
end

def size(options)
size = options.delete(:size) || DEFAULT_SIZE
size.split('x').map(&:to_i)
Expand Down
56 changes: 56 additions & 0 deletions spec/fixtures/cassettes/stability_ai/image/upscale_default.yml

Large diffs are not rendered by default.

Binary file added spec/fixtures/images/lighthouse_upscaled.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
76 changes: 76 additions & 0 deletions spec/image/stability_ai/upscale_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# frozen_string_literal: true

require 'openai'

# Specs for GenAI::Image#upscale with the Stability AI provider: one
# cassette-backed example and two examples that check the request handed
# to the API client.
RSpec.describe GenAI::Image do
  describe 'Stability AI' do
    describe '#upscale' do
      let(:provider) { :stability_ai }
      let(:instance) { described_class.new(provider, token) }
      let(:token) { ENV['API_ACCESS_TOKEN'] || 'FAKE_TOKEN' }

      let(:cassette) { 'stability_ai/image/upscale_default' }
      let(:fixture_file) { 'lighthouse_upscaled' }
      let(:original_image) { './spec/fixtures/images/lighthouse.png' }
      # The fixture is a PNG, so read it in binary mode; strict_encode64 emits
      # the base64 text without line breaks.
      let(:image_base64) { Base64.strict_encode64(File.binread("spec/fixtures/images/#{fixture_file}.png")) }

      subject { instance.upscale original_image, size: '512x512' }

      it 'creates upscaled version of an image' do
        VCR.use_cassette(cassette) do
          expect(subject).to be_a(GenAI::Result)
          expect(subject.provider).to eq(:stability_ai)

          expect(subject.model).to eq('stable-diffusion-x4-latent-upscaler')

          expect(subject.value).to be_a(String)
          expect(subject.value).to eq(image_base64)

          expect(subject.prompt_tokens).to be_nil
          expect(subject.completion_tokens).to be_nil
          expect(subject.total_tokens).to be_nil
        end
      end

      context 'with options' do
        let(:client) { double('GenAI::Api::Client') }

        # Swap the API client for a double so the outgoing request can be inspected.
        before do
          allow(GenAI::Api::Client).to receive(:new).and_return(client)
          allow(client).to receive(:post).and_return({ 'artifacts' => [] })
        end

        context 'with default options' do
          subject { instance.upscale(original_image) }

          it 'passes options to the client' do
            subject

            expect(client).to have_received(:post).with(
              '/v1/generation/stable-diffusion-x4-latent-upscaler/image-to-image/upscale', {
                image: File.binread(original_image),
                width: 256
              },
              multipart: true
            )
          end
        end

        context 'with additional options' do
          subject { instance.upscale(original_image, size: '512x512') }

          it 'passes options to the client' do
            subject

            expect(client).to have_received(:post).with(
              '/v1/generation/stable-diffusion-x4-latent-upscaler/image-to-image/upscale', {
                image: File.binread(original_image),
                width: 512
              },
              multipart: true
            )
          end
        end
      end
    end
  end
end

0 comments on commit c998e4d

Please sign in to comment.