
fix: Add fixes regarding PR comments, change models in tests (#42)
FilipZmijewski authored Feb 4, 2025
1 parent 818c888 commit f818281
Showing 4 changed files with 55 additions and 44 deletions.
19 changes: 9 additions & 10 deletions libs/langchain-community/src/chat_models/ibm.ts
@@ -150,7 +150,7 @@ function _convertToolToWatsonxTool(

function _convertMessagesToWatsonxMessages(
messages: BaseMessage[],
- model: string
+ model?: string
): TextChatResultMessage[] {
const getRole = (role: MessageType) => {
switch (role) {
@@ -174,7 +174,7 @@ function _convertMessagesToWatsonxMessages(
return message.tool_calls
.map((toolCall) => ({
...toolCall,
- id: _convertToValidToolId(model, toolCall.id ?? ""),
+ id: _convertToValidToolId(model ?? "", toolCall.id ?? ""),
}))
.map(convertLangChainToolCallToOpenAI) as TextChatToolCall[];
}
@@ -189,7 +189,7 @@ function _convertMessagesToWatsonxMessages(
role: getRole(message._getType()),
content,
name: message.name,
- tool_call_id: _convertToValidToolId(model, message.tool_call_id),
+ tool_call_id: _convertToValidToolId(model ?? "", message.tool_call_id),
};
}

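The recurring change in this file makes `model` optional and substitutes `model ?? ""` wherever a tool-call id is normalized. A minimal sketch of the pattern, with a hypothetical stand-in for `_convertToValidToolId` (the helper's real body is not part of this diff):

```ts
// Hypothetical stand-in: the real _convertToValidToolId also receives the
// model name, since some watsonx models restrict the tool-call id format.
function convertToValidToolId(model: string, toolCallId: string): string {
  // Placeholder normalization: keep only characters assumed to be accepted.
  return toolCallId.replace(/[^a-zA-Z0-9_-]/g, "");
}

// With `model` now optional, call sites guard with nullish coalescing so the
// helper keeps its non-optional `string` parameter:
declare const model: string | undefined;
const id = convertToValidToolId(model ?? "", "call_abc:123");
```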
@@ -252,7 +252,7 @@ function _watsonxResponseToChatMessage(
function _convertDeltaToMessageChunk(
delta: WatsonxDeltaStream,
rawData: TextChatResponse,
- model: string,
+ model?: string,
usage?: TextChatUsage,
defaultRole?: TextChatMessagesTextChatMessageAssistant.Constants.Role
) {
@@ -268,7 +268,7 @@ function _convertDeltaToMessageChunk(
} => ({
...toolCall,
index,
- id: _convertToValidToolId(model, toolCall.id),
+ id: _convertToValidToolId(model ?? "", toolCall.id),
type: "function",
})
)
@@ -321,7 +321,7 @@ function _convertDeltaToMessageChunk(
return new ToolMessageChunk({
content,
additional_kwargs,
- tool_call_id: _convertToValidToolId(model, rawToolCalls?.[0].id),
+ tool_call_id: _convertToValidToolId(model ?? "", rawToolCalls?.[0].id),
});
} else if (role === "function") {
return new FunctionMessageChunk({
@@ -410,7 +410,7 @@ export class ChatWatsonx<
};
}

- model: string;
+ model?: string;

version = "2024-05-31";

@@ -523,7 +523,6 @@
const { signal, promptIndex, ...rest } = options;
if (this.idOrName && Object.keys(rest).length > 0)
throw new Error("Options cannot be provided to a deployed model");
- if (this.idOrName) return undefined;

const params = {
maxTokens: options.maxTokens ?? this.maxTokens,
@@ -564,9 +563,9 @@
| { idOrName: string }
| { projectId: string; modelId: string }
| { spaceId: string; modelId: string } {
- if (this.projectId)
+ if (this.projectId && this.model)
return { projectId: this.projectId, modelId: this.model };
- else if (this.spaceId)
+ else if (this.spaceId && this.model)
return { spaceId: this.spaceId, modelId: this.model };
else if (this.idOrName) return { idOrName: this.idOrName };
else throw new Error("No scope id provided");
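The tightened `scope()` checks mean a project or space scope is only returned when a model id is also set; deployments are addressed via `idOrName` alone. A self-contained sketch of the same narrowing, using the field names from the diff:

```ts
type Scope =
  | { idOrName: string }
  | { projectId: string; modelId: string }
  | { spaceId: string; modelId: string };

function resolveScope(opts: {
  projectId?: string;
  spaceId?: string;
  idOrName?: string;
  model?: string;
}): Scope {
  // A project or space scope is valid only together with a model id.
  if (opts.projectId && opts.model)
    return { projectId: opts.projectId, modelId: opts.model };
  if (opts.spaceId && opts.model)
    return { spaceId: opts.spaceId, modelId: opts.model };
  // A deployed model is addressed by idOrName alone.
  if (opts.idOrName) return { idOrName: opts.idOrName };
  throw new Error("No scope id provided");
}
```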
@@ -34,6 +34,15 @@ class ChatWatsonxStandardIntegrationTests extends ChatModelIntegrationTests<
},
});
}

+ async testInvokeMoreComplexTools() {
+ this.skipTestMessage(
+ "testInvokeMoreComplexTools",
+ "ChatWatsonx",
+ "Watsonx does not support tool schemas which contain object with unknown/any parameters." +
+ "Watsonx only supports objects in schemas when the parameters are defined."
+ );
+ }
}

const testClass = new ChatWatsonxStandardIntegrationTests();
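The skip message above draws a line between tool schemas with declared object properties and schemas that allow arbitrary objects. A sketch of that distinction in plain JSON Schema (these example tools are illustrative, not the ones the standard suite actually sends):

```ts
// Accepted, per the skip message: every object property is declared.
const definedParamsTool = {
  type: "function",
  function: {
    name: "get_weather",
    parameters: {
      type: "object",
      properties: { city: { type: "string" } },
      required: ["city"],
    },
  },
};

// Not supported, per the skip message: an object with unknown/any parameters.
const openEndedTool = {
  type: "function",
  function: {
    name: "record_anything",
    parameters: { type: "object" }, // no `properties` declared
  },
};
```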
9 changes: 5 additions & 4 deletions libs/langchain-community/src/llms/ibm.ts
@@ -454,10 +454,11 @@ export class WatsonxLLM<
geneartionsArray[completion].stop_reason =
chunk?.generationInfo?.stop_reason;
geneartionsArray[completion].text += chunk.text;
- void runManager?.handleLLMNewToken(chunk.text, {
- prompt: promptIdx,
- completion: 0,
- });
+ if (chunk.text)
+ void runManager?.handleLLMNewToken(chunk.text, {
+ prompt: promptIdx,
+ completion: 0,
+ });
}

return geneartionsArray.map((item) => {
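The new `if (chunk.text)` guard keeps `handleLLMNewToken` from firing on empty deltas. A minimal sketch of the pattern with simplified callback types (the real `runManager` comes from LangChain's callback manager):

```ts
interface TokenCallbacks {
  handleLLMNewToken(
    token: string,
    idx: { prompt: number; completion: number }
  ): Promise<void>;
}

async function forwardTokens(
  chunks: AsyncIterable<{ text: string }>,
  promptIdx: number,
  runManager?: TokenCallbacks
): Promise<void> {
  for await (const chunk of chunks) {
    // Skip empty deltas (e.g. trailing usage-only chunks) so token
    // callbacks are never invoked with "".
    if (chunk.text)
      void runManager?.handleLLMNewToken(chunk.text, {
        prompt: promptIdx,
        completion: 0,
      });
  }
}
```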
62 changes: 32 additions & 30 deletions libs/langchain-community/src/llms/tests/ibm.int.test.ts
@@ -11,7 +11,7 @@ describe("Text generation", () => {
describe("Test invoke method", () => {
test("Correct value", async () => {
const watsonXInstance = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -21,7 +21,7 @@

test("Overwritte params", async () => {
const watsonXInstance = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -33,7 +33,7 @@

test("Invalid projectId", async () => {
const watsonXInstance = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: "Test wrong value",
@@ -43,7 +43,7 @@

test("Invalid credentials", async () => {
const watsonXInstance = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: "Test wrong value",
@@ -56,7 +56,7 @@

test("Wrong value", async () => {
const watsonXInstance = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -67,7 +67,7 @@

test("Stop", async () => {
const watsonXInstance = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -79,7 +79,7 @@

test("Stop with timeout", async () => {
const watsonXInstance = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: "sdadasdas" as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -94,7 +94,7 @@

test("Signal in call options", async () => {
const watsonXInstance = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -119,7 +119,7 @@

test("Concurenccy", async () => {
const model = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
maxConcurrency: 1,
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
@@ -139,7 +139,7 @@
input_token_count: 0,
};
const model = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
maxNewTokens: 1,
maxConcurrency: 1,
@@ -171,7 +171,7 @@
let streamedText = "";
let usedTokens = 0;
const model = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -198,7 +198,7 @@
describe("Test generate methods", () => {
test("Basic usage", async () => {
const model = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -213,20 +213,22 @@

test("Stop", async () => {
const model = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
maxNewTokens: 100,
});

const res = await model.generate(
["Print hello world!", "Print hello world hello!"],
[
"Print hello world in JavaScript!!",
"Print hello world twice in Python!",
],
{
stop: ["Hello"],
stop: ["hello"],
}
);

expect(
res.generations
.map((generation) => generation.map((item) => item.text))
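The reworded prompts and the lowercase stop sequence make this test less flaky: code generated for these prompts reliably contains a literal lowercase `hello` for the stop sequence to match. As an illustration only, mirroring the updated test setup (requires valid `WATSONX_AI_*` environment variables, and an async/ESM context for the `await`):

```ts
import { WatsonxLLM } from "@langchain/community/llms/ibm";

const model = new WatsonxLLM({
  model: "ibm/granite-3-8b-instruct",
  version: "2024-05-31",
  serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
  projectId: process.env.WATSONX_AI_PROJECT_ID,
  maxNewTokens: 100,
});

// Generation halts at the first occurrence of a stop sequence, so the
// completion should end at the lowercase "hello".
const res = await model.invoke("Print hello world in JavaScript!", {
  stop: ["hello"],
});
console.log(res);
```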
@@ -239,7 +241,7 @@
const nrNewTokens = [0, 0, 0];
const completions = ["", "", ""];
const model = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -270,7 +272,7 @@

test("Prompt value", async () => {
const model = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -290,7 +292,7 @@
let countedTokens = 0;
let streamedText = "";
const model = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -313,15 +315,15 @@

test("Stop", async () => {
const model = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
maxNewTokens: 100,
});

const stream = await model.stream("Print hello world!", {
stop: ["Hello"],
const stream = await model.stream("Print hello world in JavaScript!", {
stop: ["hello"],
});
const chunks = [];
for await (const chunk of stream) {
@@ -332,7 +334,7 @@

test("Timeout", async () => {
const model = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -354,7 +356,7 @@

test("Signal in call options", async () => {
const model = new WatsonxLLM({
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -384,7 +386,7 @@
describe("Test getNumToken method", () => {
test("Passing correct value", async () => {
const testProps: WatsonxInputLLM = {
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -402,7 +404,7 @@

test("Passing wrong value", async () => {
const testProps: WatsonxInputLLM = {
model: "ibm/granite-13b-chat-v2",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -425,7 +427,7 @@
test("Single request callback", async () => {
let callbackFlag = false;
const service = new WatsonxLLM({
model: "mistralai/mistral-large",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -445,7 +447,7 @@
test("Single response callback", async () => {
let callbackFlag = false;
const service = new WatsonxLLM({
model: "mistralai/mistral-large",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -467,7 +469,7 @@
let callbackFlagReq = false;
let callbackFlagRes = false;
const service = new WatsonxLLM({
model: "mistralai/mistral-large",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -495,7 +497,7 @@
let langchainCallback = false;

const service = new WatsonxLLM({
model: "mistralai/mistral-large",
model: "ibm/granite-3-8b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
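Taken together, the integration tests now target `ibm/granite-3-8b-instruct` in place of `ibm/granite-13b-chat-v2` and `mistralai/mistral-large`. A usage sketch of the chat class after this PR; the deployment env var name is an assumption, and per the diff, options cannot be provided to a deployed model:

```ts
import { ChatWatsonx } from "@langchain/community/chat_models/ibm";

// Project-scoped usage: `model` must accompany `projectId` (see scope()).
const chat = new ChatWatsonx({
  model: "ibm/granite-3-8b-instruct",
  version: "2024-05-31",
  serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
  projectId: process.env.WATSONX_AI_PROJECT_ID,
});

// Deployment-scoped usage: `idOrName` alone; `model` may now be omitted.
// WATSONX_AI_DEPLOYMENT_ID is a hypothetical variable name for this sketch.
const deployed = new ChatWatsonx({
  idOrName: process.env.WATSONX_AI_DEPLOYMENT_ID as string,
  version: "2024-05-31",
  serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
});
```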
