diff --git a/examples/pipelines/providers/mlx_pipeline.py b/examples/pipelines/providers/mlx_pipeline.py index 39216776..8365ab2e 100644 --- a/examples/pipelines/providers/mlx_pipeline.py +++ b/examples/pipelines/providers/mlx_pipeline.py @@ -2,55 +2,55 @@ title: MLX Pipeline author: justinh-rahb date: 2024-05-27 -version: 1.1 +version: 1.2 license: MIT description: A pipeline for generating text using Apple MLX Framework. requirements: requests, mlx-lm, huggingface-hub -environment_variables: MLX_HOST, MLX_PORT, MLX_MODEL, MLX_STOP, MLX_SUBPROCESS, HUGGINGFACE_TOKEN +environment_variables: MLX_HOST, MLX_PORT, MLX_SUBPROCESS """ from typing import List, Union, Generator, Iterator from schemas import OpenAIChatMessage +from pydantic import BaseModel import requests import os import subprocess import logging from huggingface_hub import login - class Pipeline: + class Valves(BaseModel): + MLX_MODEL: str = "mistralai/Mistral-7B-Instruct-v0.3" + MLX_STOP: str = "[INST]" + HUGGINGFACE_TOKEN: str = "" + def __init__(self): - # Optionally, you can set the id and name of the pipeline. - # Best practice is to not specify the id so that it can be automatically inferred from the filename, so that users can install multiple versions of the same pipeline. - # The identifier must be unique across all pipelines. - # The identifier must be an alphanumeric string that can include underscores or hyphens. It cannot contain spaces, special characters, slashes, or backslashes. - # self.id = "mlx_pipeline" + self.id = "mlx_pipeline" self.name = "MLX Pipeline" + self.valves = self.Valves() + self.update_valves() + self.host = os.getenv("MLX_HOST", "localhost") self.port = os.getenv("MLX_PORT", "8080") - self.model = os.getenv("MLX_MODEL", "mistralai/Mistral-7B-Instruct-v0.2") - self.stop_sequence = os.getenv("MLX_STOP", "[INST]").split( - "," - ) # Default stop sequence is [INST] self.subprocess = os.getenv("MLX_SUBPROCESS", "true").lower() == "true" - self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN", None) - - if self.huggingface_token: - login(self.huggingface_token) if self.subprocess: self.start_mlx_server() + def update_valves(self): + if self.valves.HUGGINGFACE_TOKEN: + login(self.valves.HUGGINGFACE_TOKEN) + self.stop_sequence = self.valves.MLX_STOP.split(",") + def start_mlx_server(self): if not os.getenv("MLX_PORT"): self.port = self.find_free_port() - command = f"mlx_lm.server --model {self.model} --port {self.port}" - self.server_process = subprocess.Popen(command, shell=True) - logging.info(f"Started MLX server on port {self.port}") + command = f"mlx_lm.server --model {self.valves.MLX_MODEL} --port {self.port}" + self.server_process = subprocess.Popen(command, shell=True) + logging.info(f"Started MLX server on port {self.port}") def find_free_port(self): import socket - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("", 0)) port = s.getsockname()[1] @@ -65,6 +65,13 @@ async def on_shutdown(self): self.server_process.terminate() logging.info(f"Terminated MLX server on port {self.port}") + async def on_valves_updated(self): + self.update_valves() + if self.subprocess and hasattr(self, "server_process"): + self.server_process.terminate() + logging.info(f"Terminated MLX server on port {self.port}") + self.start_mlx_server() + def pipe( self, user_message: str, model_id: str, messages: List[dict], body: dict ) -> Union[str, Generator, Iterator]: @@ -73,18 +80,17 @@ def pipe( url = f"http://{self.host}:{self.port}/v1/chat/completions" headers = {"Content-Type": "application/json"} - # Extract and validate parameters from the request body max_tokens = body.get("max_tokens", 4096) if not isinstance(max_tokens, int) or max_tokens < 0: - max_tokens = 4096 # Default to 4096 if invalid + max_tokens = 4096 temperature = body.get("temperature", 0.8) if not isinstance(temperature, (int, float)) or temperature < 0: - temperature = 0.8 # Default to 0.8 if invalid + temperature = 0.8 repeat_penalty = body.get("repeat_penalty", 1.0) if not isinstance(repeat_penalty, (int, float)) or repeat_penalty < 0: - repeat_penalty = 1.0 # Default to 1.0 if invalid + repeat_penalty = 1.0 payload = { "messages": messages, @@ -106,4 +112,4 @@ def pipe( else: return r.json() except Exception as e: - return f"Error: {e}" + return f"Error: {e}" \ No newline at end of file