Skip to content

Commit

Permalink
REFACTOR: Separation of concerns for embedding generation.
Browse files Browse the repository at this point in the history
In a previous refactor, we moved the responsibility of querying and storing embeddings into the `Schema` class. Now, it's time for embedding generation.

The motivation behind these changes is to isolate vector characteristics in simple objects to later replace them with a DB-backed version, similar to what we did with LLM configs.
  • Loading branch information
romanrizzi committed Dec 13, 2024
1 parent eae527f commit 3b3fac2
Show file tree
Hide file tree
Showing 36 changed files with 375 additions and 496 deletions.
4 changes: 1 addition & 3 deletions app/jobs/regular/generate_embeddings.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ def execute(args)
return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
return if post.raw.blank?

vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation

vector_rep.generate_representation_from(target)
DiscourseAi::Embeddings::Vector.instance.generate_representation_from(target)
end
end
end
4 changes: 2 additions & 2 deletions app/jobs/regular/generate_rag_embeddings.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ class GenerateRagEmbeddings < ::Jobs::Base
def execute(args)
return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty?

vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector = DiscourseAi::Embeddings::Vector.instance

# generate_representation_from compares the digest value to make sure
# the embedding is only generated once per fragment unless something changes.
fragments.map { |fragment| vector_rep.generate_representation_from(fragment) }
fragments.map { |fragment| vector.generate_representation_from(fragment) }

last_fragment = fragments.last
target = last_fragment.target
Expand Down
27 changes: 14 additions & 13 deletions app/jobs/scheduled/embeddings_backfill.rb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def execute(args)

rebaked = 0

vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector = DiscourseAi::Embeddings::Vector.instance
vector_def = vector.vdef
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE

topics =
Expand All @@ -30,19 +31,19 @@ def execute(args)
.where(deleted_at: nil)
.order("topics.bumped_at DESC")

rebaked += populate_topic_embeddings(vector_rep, topics.limit(limit - rebaked))
rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked))

return if rebaked >= limit

# Then, we'll try to backfill embeddings for topics that have outdated
# embeddings, be it model or strategy version
relation = topics.where(<<~SQL).limit(limit - rebaked)
#{table_name}.model_version < #{vector_rep.version}
#{table_name}.model_version < #{vector_def.version}
OR
#{table_name}.strategy_version < #{vector_rep.strategy_version}
#{table_name}.strategy_version < #{vector_def.strategy_version}
SQL

rebaked += populate_topic_embeddings(vector_rep, relation)
rebaked += populate_topic_embeddings(vector, relation)

return if rebaked >= limit

Expand All @@ -54,7 +55,7 @@ def execute(args)
.where("#{table_name}.updated_at < topics.updated_at")
.limit((limit - rebaked) / 10)

populate_topic_embeddings(vector_rep, relation, force: true)
populate_topic_embeddings(vector, relation, force: true)

return if rebaked >= limit

Expand All @@ -76,7 +77,7 @@ def execute(args)
.limit(limit - rebaked)
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end

Expand All @@ -86,14 +87,14 @@ def execute(args)
# embeddings, be it model or strategy version
posts
.where(<<~SQL)
#{table_name}.model_version < #{vector_rep.version}
#{table_name}.model_version < #{vector_def.version}
OR
#{table_name}.strategy_version < #{strategy.version}
#{table_name}.strategy_version < #{vector_def.strategy_version}
SQL
.limit(limit - rebaked)
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end

Expand All @@ -107,7 +108,7 @@ def execute(args)
.limit((limit - rebaked) / 10)
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end

Expand All @@ -116,7 +117,7 @@ def execute(args)

private

def populate_topic_embeddings(vector_rep, topics, force: false)
def populate_topic_embeddings(vector, topics, force: false)
done = 0

topics =
Expand All @@ -126,7 +127,7 @@ def populate_topic_embeddings(vector_rep, topics, force: false)
batch_size = 1000

ids.each_slice(batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
vector.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
done += batch.length
end

Expand Down
6 changes: 3 additions & 3 deletions lib/ai_bot/personas/persona.rb
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,10 @@ def rag_fragments_prompt(conversation_context, llm:, user:)

return nil if !consolidated_question

vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector = DiscourseAi::Embeddings::Vector.instance
reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings

interactions_vector = vector_rep.vector_from(consolidated_question)
interactions_vector = vector.vector_from(consolidated_question)

rag_conversation_chunks = self.class.rag_conversation_chunks
search_limit =
Expand All @@ -327,7 +327,7 @@ def rag_fragments_prompt(conversation_context, llm:, user:)
rag_conversation_chunks
end

schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector.vdef)

candidate_fragment_ids =
schema
Expand Down
5 changes: 2 additions & 3 deletions lib/ai_bot/tool_runner.rb
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,10 @@ def rag_search(query, filenames: nil, limit: 10)

return [] if upload_refs.empty?

vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
query_vector = vector_rep.vector_from(query)
query_vector = DiscourseAi::Embeddings::Vector.instance.vector_from(query)
fragment_ids =
DiscourseAi::Embeddings::Schema
.for(RagDocumentFragment, vector: vector_rep)
.for(RagDocumentFragment)
.asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder|
builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool")
rag_document_fragments ON
Expand Down
6 changes: 3 additions & 3 deletions lib/ai_helper/semantic_categorizer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def tags
private

def nearest_neighbors(limit: 100)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
vector = DiscourseAi::Embeddings::Vector.instance
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)

raw_vector = vector_rep.vector_from(@text)
raw_vector = vector.vector_from(@text)

muted_category_ids = nil
if @user.present?
Expand Down
44 changes: 27 additions & 17 deletions lib/embeddings/schema.rb
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,31 @@ class Schema

def self.for(
target_klass,
vector: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_def: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
)
case target_klass&.name
when "Topic"
new(TOPICS_TABLE, "topic_id", vector)
new(TOPICS_TABLE, "topic_id", vector_def)
when "Post"
new(POSTS_TABLE, "post_id", vector)
new(POSTS_TABLE, "post_id", vector_def)
when "RagDocumentFragment"
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector)
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def)
else
raise ArgumentError, "Invalid target type for embeddings"
end
end

def initialize(table, target_column, vector)
def initialize(table, target_column, vector_def)
@table = table
@target_column = target_column
@vector = vector
@vector_def = vector_def
end

attr_reader :table, :target_column, :vector
attr_reader :table, :target_column, :vector_def

def find_by_embedding(embedding)
DB.query(<<~SQL, query_embedding: embedding, vid: vector.id, vsid: vector.strategy_id).first
DB.query(
<<~SQL,
SELECT *
FROM #{table}
WHERE
Expand All @@ -46,10 +47,15 @@ def find_by_embedding(embedding)
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT 1
SQL
query_embedding: embedding,
vid: vector_def.id,
vsid: vector_def.strategy_id,
).first
end

def find_by_target(target)
DB.query(<<~SQL, target_id: target.id, vid: vector.id, vsid: vector.strategy_id).first
DB.query(
<<~SQL,
SELECT *
FROM #{table}
WHERE
Expand All @@ -58,6 +64,10 @@ def find_by_target(target)
#{target_column} = :target_id
LIMIT 1
SQL
target_id: target.id,
vid: vector_def.id,
vsid: vector_def.strategy_id,
).first
end

def asymmetric_similarity_search(embedding, limit:, offset:)
Expand Down Expand Up @@ -87,8 +97,8 @@ def asymmetric_similarity_search(embedding, limit:, offset:)

builder.where(
"model_id = :model_id AND strategy_id = :strategy_id",
model_id: vector.id,
strategy_id: vector.strategy_id,
model_id: vector_def.id,
strategy_id: vector_def.strategy_id,
)

yield(builder) if block_given?
Expand Down Expand Up @@ -156,7 +166,7 @@ def symmetric_similarity_search(record)

yield(builder) if block_given?

builder.query(vid: vector.id, vsid: vector.strategy_id, target_id: record.id)
builder.query(vid: vector_def.id, vsid: vector_def.strategy_id, target_id: record.id)
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
Expand All @@ -176,10 +186,10 @@ def store(record, embedding, digest)
updated_at = :now
SQL
target_id: record.id,
model_id: vector.id,
model_version: vector.version,
strategy_id: vector.strategy_id,
strategy_version: vector.strategy_version,
model_id: vector_def.id,
model_version: vector_def.version,
strategy_id: vector_def.strategy_id,
strategy_version: vector_def.strategy_version,
digest: digest,
embeddings: embedding,
now: Time.zone.now,
Expand All @@ -188,7 +198,7 @@ def store(record, embedding, digest)

private

delegate :dimensions, :pg_function, to: :vector
delegate :dimensions, :pg_function, to: :vector_def
end
end
end
3 changes: 1 addition & 2 deletions lib/embeddings/semantic_related.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ def self.clear_cache_for(topic)
def related_topic_ids_for(topic)
return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1

vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
cache_for = results_ttl(topic)

Discourse
.cache
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
DiscourseAi::Embeddings::Schema
.for(Topic, vector: vector_rep)
.for(Topic)
.symmetric_similarity_search(topic)
.map(&:topic_id)
.tap do |candidate_ids|
Expand Down
18 changes: 8 additions & 10 deletions lib/embeddings/semantic_search.rb
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def cached_query?(query)
Discourse.cache.read(embedding_key).present?
end

def vector_rep
@vector_rep ||= DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
def vector
@vector ||= DiscourseAi::Embeddings::Vector.instance
end

def hyde_embedding(search_term)
Expand All @@ -52,16 +52,14 @@ def hyde_embedding(search_term)

Discourse
.cache
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) }
.fetch(embedding_key, expires_in: 1.week) { vector.vector_from(hypothetical_post) }
end

def embedding(search_term)
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
embedding_key = build_embedding_key(digest, "", SiteSetting.ai_embeddings_model)

Discourse
.cache
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(search_term) }
Discourse.cache.fetch(embedding_key, expires_in: 1.week) { vector.vector_from(search_term) }
end

# this ensures the candidate topics are over selected
Expand All @@ -84,7 +82,7 @@ def search_for_topics(query, page = 1, hyde: true)

over_selection_limit = limit * OVER_SELECTION_FACTOR

schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)

candidate_topic_ids =
schema.asymmetric_similarity_search(
Expand Down Expand Up @@ -114,7 +112,7 @@ def quick_search(query)

return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length

vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector = DiscourseAi::Embeddings::Vector.instance

digest = OpenSSL::Digest::SHA1.hexdigest(search_term)

Expand All @@ -129,12 +127,12 @@ def quick_search(query)
Discourse
.cache
.fetch(embedding_key, expires_in: 1.week) do
vector_rep.vector_from(search_term, asymetric: true)
vector.vector_from(search_term, asymetric: true)
end

candidate_post_ids =
DiscourseAi::Embeddings::Schema
.for(Post, vector: vector_rep)
.for(Post, vector_def: vector.vdef)
.asymmetric_similarity_search(
search_term_embedding,
limit: max_semantic_results_per_page,
Expand Down
17 changes: 13 additions & 4 deletions lib/embeddings/strategies/truncation.rb
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,28 @@ def version
1
end

def prepare_text_from(target, tokenizer, max_length)
def prepare_target_text(target, vdef)
max_length = vdef.max_sequence_length - 2

case target
when Topic
topic_truncation(target, tokenizer, max_length)
topic_truncation(target, vdef.tokenizer, max_length)
when Post
post_truncation(target, tokenizer, max_length)
post_truncation(target, vdef.tokenizer, max_length)
when RagDocumentFragment
tokenizer.truncate(target.fragment, max_length)
vdef.tokenizer.truncate(target.fragment, max_length)
else
raise ArgumentError, "Invalid target type"
end
end

# Builds the text used to generate an embedding for a search query.
#
# When +asymetric+ is true, the vector definition's asymmetric query prefix
# is prepended so the embedding model can distinguish queries from documents.
# The resulting text is truncated with the definition's tokenizer to
# max_sequence_length - 2 (presumably reserving room for special tokens —
# TODO confirm against the vector definitions).
#
# @param text [String] the raw query text
# @param vdef [Object] vector definition exposing #tokenizer,
#   #max_sequence_length and #asymmetric_query_prefix
# @param asymetric [Boolean] prepend the asymmetric query prefix
#   (parameter spelling kept as-is for interface compatibility)
# @return [String] the prepared, truncated query text
def prepare_query_text(text, vdef, asymetric: false)
  qtext = asymetric ? "#{vdef.asymmetric_query_prefix} #{text}" : text
  max_length = vdef.max_sequence_length - 2

  # Truncate the (possibly prefixed) qtext. The original truncated `text`,
  # leaving qtext unused and silently dropping the asymmetric prefix.
  vdef.tokenizer.truncate(qtext, max_length)
end

private

def topic_information(topic)
Expand Down
Loading

0 comments on commit 3b3fac2

Please sign in to comment.