Skip to content

Commit

Permalink
Fixes: Self-query fixes (#2255)
Browse files Browse the repository at this point in the history
* init

* fixed supabase rpc call filter merger

* added tests for default metadata filtering for self query. Made weaviate self query less brittle by adding parse to types. added merge filters for default and generated filters

* main merged

* renamed mergeFilterMethod to mergeFiltersOperator

* removed console.log from leftover from testing

* oops formatting

* made type more strict for self-query, with autocompletion with ts

* skip weaviate self query integration tests

* small fixes

* formatting

* fixed integration tests

* Update env var name

---------

Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
  • Loading branch information
ppramesi and jacoblee93 authored Aug 15, 2023
1 parent ce66e81 commit ff230cf
Show file tree
Hide file tree
Showing 15 changed files with 1,391 additions and 92 deletions.
16 changes: 6 additions & 10 deletions langchain/src/chains/query_constructor/ir.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { VectorStore } from "../../vectorstores/base.js";

export type AND = "and";
export type OR = "or";
export type NOT = "not";
Expand Down Expand Up @@ -28,10 +30,7 @@ export const Comparators: { [key: string]: Comparator } = {
gte: "gte",
};

export type VisitorResult =
| VisitorOperationResult
| VisitorComparisonResult
| VisitorStructuredQueryResult;
export type VisitorResult = VisitorOperationResult | VisitorComparisonResult;

export type VisitorOperationResult = {
[operator: string]: VisitorResult[];
Expand All @@ -44,18 +43,15 @@ export type VisitorComparisonResult = {
};

export type VisitorStructuredQueryResult = {
filter?:
| VisitorStructuredQueryResult
| VisitorComparisonResult
| VisitorOperationResult;
filter?: VisitorComparisonResult | VisitorOperationResult;
};

export abstract class Visitor {
export abstract class Visitor<T extends VectorStore = VectorStore> {
declare VisitOperationOutput: object;

declare VisitComparisonOutput: object;

declare VisitStructuredQueryOutput: { filter?: object };
declare VisitStructuredQueryOutput: { filter?: T["FilterType"] };

abstract allowedOperators: Operator[];

Expand Down
4 changes: 3 additions & 1 deletion langchain/src/chains/query_constructor/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ statements): one or more statements to apply the operation to
Make sure that you only use the comparators and logical operators listed above and \
no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
Make sure that filters only use format \`YYYY-MM-DD\` when handling timestamp data typed values.
Make sure that filters take into account the descriptions of attributes and only make \
comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be \
Expand All @@ -122,7 +124,7 @@ export const DEFAULT_SUFFIX = `\
Data Source:
\`\`\`json
{{{{
"content": {content},
"content": "{content}",
"attributes": {attributes}
}}}}
\`\`\`
Expand Down
49 changes: 47 additions & 2 deletions langchain/src/retrievers/self_query/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,29 @@ import {
VisitorResult,
VisitorStructuredQueryResult,
} from "../../chains/query_constructor/ir.js";
import { VectorStore } from "../../vectorstores/base.js";
import { isFilterEmpty } from "./utils.js";

export type TranslatorOpts = {
allowedOperators: Operator[];
allowedComparators: Comparator[];
};

export abstract class BaseTranslator extends Visitor {
export abstract class BaseTranslator<
T extends VectorStore = VectorStore
> extends Visitor<T> {
abstract formatFunction(func: Operator | Comparator): string;

abstract mergeFilters(
defaultFilter: this["VisitStructuredQueryOutput"]["filter"] | undefined,
generatedFilter: this["VisitStructuredQueryOutput"]["filter"] | undefined,
mergeType?: "and" | "or" | "replace"
): this["VisitStructuredQueryOutput"]["filter"] | undefined;
}

export class BasicTranslator extends BaseTranslator {
export class BasicTranslator<
T extends VectorStore = VectorStore
> extends BaseTranslator<T> {
declare VisitOperationOutput: VisitorOperationResult;

declare VisitComparisonOutput: VisitorComparisonResult;
Expand Down Expand Up @@ -106,4 +118,37 @@ export class BasicTranslator extends BaseTranslator {
}
return nextArg;
}

mergeFilters(
defaultFilter: VisitorStructuredQueryResult["filter"] | undefined,
generatedFilter: VisitorStructuredQueryResult["filter"] | undefined,
mergeType = "and"
): VisitorStructuredQueryResult["filter"] | undefined {
if (isFilterEmpty(defaultFilter) && isFilterEmpty(generatedFilter)) {
return undefined;
}
if (isFilterEmpty(defaultFilter) || mergeType === "replace") {
if (isFilterEmpty(generatedFilter)) {
return undefined;
}
return generatedFilter;
}
if (isFilterEmpty(generatedFilter)) {
if (mergeType === "and") {
return undefined;
}
return defaultFilter;
}
if (mergeType === "and") {
return {
$and: [defaultFilter, generatedFilter],
};
} else if (mergeType === "or") {
return {
$or: [defaultFilter, generatedFilter],
};
} else {
throw new Error("Unknown merge type");
}
}
}
3 changes: 2 additions & 1 deletion langchain/src/retrievers/self_query/chroma.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { Comparators, Operators } from "../../chains/query_constructor/ir.js";
import { Chroma } from "../../vectorstores/chroma.js";
import { BasicTranslator } from "./base.js";

export class ChromaTranslator extends BasicTranslator {
export class ChromaTranslator<T extends Chroma> extends BasicTranslator<T> {
constructor() {
super({
allowedOperators: [Operators.and, Operators.or],
Expand Down
39 changes: 37 additions & 2 deletions langchain/src/retrievers/self_query/functional.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
} from "../../chains/query_constructor/ir.js";
import { Document } from "../../document.js";
import { BaseTranslator } from "./base.js";
import { isFilterEmpty } from "./utils.js";

type ValueType = {
eq: string | number;
Expand All @@ -26,7 +27,9 @@ export class FunctionalTranslator extends BaseTranslator {

declare VisitComparisonOutput: FunctionFilter;

declare VisitStructuredQueryOutput: { filter: FunctionFilter };
declare VisitStructuredQueryOutput:
| { filter: FunctionFilter }
| { [k: string]: never };

allowedOperators: Operator[] = [Operators.and, Operators.or];

Expand Down Expand Up @@ -132,12 +135,44 @@ export class FunctionalTranslator extends BaseTranslator {
query: StructuredQuery
): this["VisitStructuredQueryOutput"] {
if (!query.filter) {
return { filter: () => false };
return {};
}
const filterFunction = query.filter?.accept(this);
if (typeof filterFunction !== "function") {
throw new Error("Structured query filter is not a function");
}
return { filter: filterFunction as FunctionFilter };
}

mergeFilters(
defaultFilter: FunctionFilter,
generatedFilter: FunctionFilter,
mergeType = "and"
): FunctionFilter | undefined {
if (isFilterEmpty(defaultFilter) && isFilterEmpty(generatedFilter)) {
return undefined;
}
if (isFilterEmpty(defaultFilter) || mergeType === "replace") {
if (isFilterEmpty(generatedFilter)) {
return undefined;
}
return generatedFilter;
}
if (isFilterEmpty(generatedFilter)) {
if (mergeType === "and") {
return undefined;
}
return defaultFilter;
}

if (mergeType === "and") {
return (document: Document) =>
defaultFilter(document) && generatedFilter(document);
} else if (mergeType === "or") {
return (document: Document) =>
defaultFilter(document) || generatedFilter(document);
} else {
throw new Error("Unknown merge type");
}
}
}
68 changes: 43 additions & 25 deletions langchain/src/retrievers/self_query/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { LLMChain } from "../../chains/llm_chain.js";
import {
QueryConstructorChainOptions,
Expand All @@ -13,44 +14,51 @@ import { CallbackManagerForRetrieverRun } from "../../callbacks/manager.js";

export { BaseTranslator, BasicTranslator, FunctionalTranslator };

export interface SelfQueryRetrieverArgs extends BaseRetrieverInput {
vectorStore: VectorStore;
structuredQueryTranslator: BaseTranslator;
export interface SelfQueryRetrieverArgs<T extends VectorStore>
extends BaseRetrieverInput {
vectorStore: T;
structuredQueryTranslator: BaseTranslator<T>;
llmChain: LLMChain;
verbose?: boolean;
useOriginalQuery?: boolean;
searchParams?: {
k?: number;
filter?: VectorStore["FilterType"];
filter?: T["FilterType"];
mergeFiltersOperator?: "or" | "and" | "replace";
};
}
export class SelfQueryRetriever

export class SelfQueryRetriever<T extends VectorStore>
extends BaseRetriever
implements SelfQueryRetrieverArgs
implements SelfQueryRetrieverArgs<T>
{
get lc_namespace() {
return ["langchain", "retrievers", "self_query"];
}

vectorStore: VectorStore;
vectorStore: T;

llmChain: LLMChain;

verbose?: boolean;

structuredQueryTranslator: BaseTranslator;
structuredQueryTranslator: BaseTranslator<T>;

useOriginalQuery = false;

searchParams?: {
k?: number;
filter?: VectorStore["FilterType"];
filter?: T["FilterType"];
mergeFiltersOperator?: "or" | "and" | "replace";
} = { k: 4 };

constructor(options: SelfQueryRetrieverArgs) {
constructor(options: SelfQueryRetrieverArgs<T>) {
super(options);
this.vectorStore = options.vectorStore;
this.llmChain = options.llmChain;
this.verbose = options.verbose ?? false;
this.searchParams = options.searchParams ?? this.searchParams;

this.useOriginalQuery = options.useOriginalQuery ?? this.useOriginalQuery;
this.structuredQueryTranslator = options.structuredQueryTranslator;
}

Expand All @@ -65,31 +73,41 @@ export class SelfQueryRetriever
runManager?.getChild("llm_chain")
);

const generatedStructuredQuery = output as StructuredQuery;

const nextArg = this.structuredQueryTranslator.visitStructuredQuery(
output as StructuredQuery
generatedStructuredQuery
);

if (nextArg.filter) {
return this.vectorStore.similaritySearch(
query,
this.searchParams?.k,
nextArg.filter,
runManager?.getChild("vectorstore")
);
const filter = this.structuredQueryTranslator.mergeFilters(
this.searchParams?.filter,
nextArg.filter,
this.searchParams?.mergeFiltersOperator
);

const generatedQuery = generatedStructuredQuery.query;
let myQuery = query;

if (!this.useOriginalQuery && generatedQuery && generatedQuery.length > 0) {
myQuery = generatedQuery;
}

if (!filter) {
return [];
} else {
return this.vectorStore.similaritySearch(
query,
myQuery,
this.searchParams?.k,
this.searchParams?.filter,
filter,
runManager?.getChild("vectorstore")
);
}
}

static fromLLM(
static fromLLM<T extends VectorStore>(
options: QueryConstructorChainOptions &
Omit<SelfQueryRetrieverArgs, "llmChain">
): SelfQueryRetriever {
Omit<SelfQueryRetrieverArgs<T>, "llmChain">
): SelfQueryRetriever<T> {
const {
structuredQueryTranslator,
allowedComparators,
Expand All @@ -111,7 +129,7 @@ export class SelfQueryRetriever
allowedOperators:
allowedOperators ?? structuredQueryTranslator.allowedOperators,
});
return new SelfQueryRetriever({
return new SelfQueryRetriever<T>({
...rest,
llmChain,
vectorStore,
Expand Down
5 changes: 4 additions & 1 deletion langchain/src/retrievers/self_query/pinecone.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import { Comparators, Operators } from "../../chains/query_constructor/ir.js";
import { PineconeStore } from "../../vectorstores/pinecone.js";
import { BasicTranslator } from "./base.js";

export class PineconeTranslator extends BasicTranslator {
export class PineconeTranslator<
T extends PineconeStore
> extends BasicTranslator<T> {
constructor() {
super({
allowedOperators: [Operators.and, Operators.or],
Expand Down
Loading

1 comment on commit ff230cf

@vercel
Copy link

@vercel vercel bot commented on ff230cf Aug 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.