Skip to content

Commit

Permalink
fix(langchain): Fix Groq import for hub (#7620)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Jan 29, 2025
1 parent a4404fb commit 3c53dcd
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 11 deletions.
30 changes: 26 additions & 4 deletions langchain/src/hub/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,6 @@ export function generateModelImportMap(
) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const modelImportMap: Record<string, any> = {};
// TODO: Fix in 0.4.0. We can't get lc_id without instantiating the class, so we
// must put them inline here. In the future, make this less hacky
// This should probably use dynamic imports and have a web-only entrypoint
// in a future breaking release
if (modelClass !== undefined) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const modelLcName = (modelClass as any)?.lc_name();
Expand Down Expand Up @@ -130,3 +126,29 @@ export function generateModelImportMap(
}
return modelImportMap;
}

export function generateOptionalImportMap(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
modelClass?: new (...args: any[]) => BaseLanguageModel
) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const optionalImportMap: Record<string, any> = {};
if (modelClass !== undefined) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const modelLcName = (modelClass as any)?.lc_name();
let optionalImportMapKey;
if (modelLcName === "ChatGoogleGenerativeAI") {
optionalImportMapKey = "langchain_google_genai/chat_models";
} else if (modelLcName === "ChatBedrockConverse") {
optionalImportMapKey = "langchain_aws/chat_models";
} else if (modelLcName === "ChatGroq") {
optionalImportMapKey = "langchain_groq/chat_models";
}
if (optionalImportMapKey !== undefined) {
optionalImportMap[optionalImportMapKey] = {
[modelLcName]: modelClass,
};
}
}
return optionalImportMap;
}
9 changes: 7 additions & 2 deletions langchain/src/hub/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import { Runnable } from "@langchain/core/runnables";
import type { BaseLanguageModel } from "@langchain/core/language_models/base";
import { load } from "../load/index.js";
import { basePush, basePull, generateModelImportMap } from "./base.js";
import {
basePush,
basePull,
generateModelImportMap,
generateOptionalImportMap,
} from "./base.js";

export { basePush as push };

Expand Down Expand Up @@ -36,7 +41,7 @@ export async function pull<T extends Runnable>(
const loadedPrompt = await load<T>(
JSON.stringify(promptObject.manifest),
undefined,
undefined,
generateOptionalImportMap(options?.modelClass),
generateModelImportMap(options?.modelClass)
);
return loadedPrompt;
Expand Down
9 changes: 7 additions & 2 deletions langchain/src/hub/node.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import { Runnable } from "@langchain/core/runnables";
import { basePush, basePull, generateModelImportMap } from "./base.js";
import {
basePush,
basePull,
generateModelImportMap,
generateOptionalImportMap,
} from "./base.js";
import { load } from "../load/index.js";

// TODO: Make this the default, add web entrypoint in next breaking release
Expand Down Expand Up @@ -55,7 +60,7 @@ export async function pull<T extends Runnable>(
const loadedPrompt = await load<T>(
JSON.stringify(promptObject.manifest),
undefined,
undefined,
generateOptionalImportMap(modelClass),
generateModelImportMap(modelClass)
);
return loadedPrompt;
Expand Down
7 changes: 4 additions & 3 deletions langchain/src/hub/tests/hub.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,13 @@ test("Test LangChain Hub while loading model", async () => {
});

test("Test LangChain Hub while loading model with dynamic imports", async () => {
const pulledPrompt = await nodePull("jacob/lahzo-testing", {
const pulledPrompt = await nodePull("jacob/groq-test", {
includeModel: true,
});
const res = await pulledPrompt.invoke({
agent: { name: "testing" },
messages: [new AIMessage("foo")],
question:
"Who is the current president of the USA as of today? You must use the provided tool for the latest info.",
});
expect(res).toBeInstanceOf(AIMessage);
expect(res.tool_calls?.length).toEqual(1);
});

0 comments on commit 3c53dcd

Please sign in to comment.