
blaizzy

Professional software vendor delivering innovative solutions on the Softono platform. Specialized in both open-source and proprietary software development.

Total Products
2

Software by blaizzy

mlx-vlm
Open Source


[![Upload Python Package](https://github.com/Blaizzy/mlx-vlm/actions/workflows/python-publish.yml/badge.svg)](https://github.com/Blaizzy/mlx-vlm/actions/workflows/python-publish.yml)

# MLX-VLM

MLX-VLM is a package for inference and fine-tuning of Vision Language Models (VLMs) and Omni Models (VLMs with audio and video support) on your Mac using MLX.

## Table of Contents

- [Installation](#installation)
- [Usage](#usage)
- [Command Line Interface (CLI)](#command-line-interface-cli)
- [Thinking Budget](#thinking-budget)
- [Speculative Decoding](#speculative-decoding)
- [DFlash (Qwen3.5)](#dflash-qwen35)
- [Gemma 4 MTP](#gemma-4-mtp)
- [Chat UI with Gradio](#chat-ui-with-gradio)
- [Python Script](#python-script)
- [Server (FastAPI)](#server-fastapi)
- [Continuous Batching](#continuous-batching)
- [Automatic Prefix Caching (APC)](#automatic-prefix-caching-apc)
- [KV Cache Quantization](#kv-cache-quantization)
- [Activation Quantization (CUDA)](#activation-quantization-cuda)
- [Multi-Image Chat Support](#multi-image-chat-support)
- [Supported Models](#supported-models)
- [Usage Examples](#usage-examples)
- [Model-Specific Documentation](#model-specific-documentation)
- [Vision Feature Caching](#vision-feature-caching)
- [TurboQuant KV Cache](#turboquant-kv-cache)
- [Distributed Inference](#distributed-inference)
- [Fine-tuning](#fine-tuning)

## Model-Specific Documentation

Some models have detailed documentation with prompt formats, examples, and best practices:

| Model | Documentation |
|-------|---------------|
| DeepSeek-OCR | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/deepseekocr/README.md) |
| DeepSeek-OCR-2 | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/deepseekocr_2/README.md) |
| DOTS-OCR | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/dots_ocr/README.md) |
| DOTS-MOCR | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/dots_ocr/README.md) |
| ERNIE 4.5 VL | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/ernie4_5_moe_vl/README.md) |
| GLM-OCR | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/glm_ocr/README.md) |
| Phi-4 Reasoning Vision | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/phi4_siglip/README.md) |
| MiniCPM-o | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/minicpmo/README.md) |
| PaddleOCR-VL | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/paddleocr_vl/README.md) |
| Phi-4 Multimodal | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/phi4mm/README.md) |
| MolmoPoint | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/molmo_point/README.md) |
| LocateAnything | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/locateanything/README.md) |
| Moondream3 | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/moondream3/README.md) |
| Gemma 4 | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/gemma4/README.md) |
| Falcon-OCR | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/falcon_ocr/README.md) |
| Granite Vision 3.2 | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/granite_vision/README.md) |
| Granite 4.0 Vision | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/granite4_vision/README.md) |
| MiniCPM-V 4.6 | [Docs](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/models/minicpmv4_6/README.md) |

## Installation

The easiest way to get started is to install the `mlx-vlm` package using pip:

```sh
pip install -U mlx-vlm
```

## Usage

### Command Line Interface (CLI)

Generate output from a model using the CLI:

```sh
# Text generation
mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --prompt "Hello, how are you?"

# Image generation
mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --temperature 0.0 --image http://images.cocodataset.org/val2017/000000039769.jpg

# Audio generation (New)
mlx_vlm.generate --model mlx-community/gemma-3n-E2B-it-4bit --max-tokens 100 --prompt "Describe what you hear" --audio /path/to/audio.wav

# Multi-modal generation (Image + Audio)
mlx_vlm.generate --model mlx-community/gemma-3n-E2B-it-4bit --max-tokens 100 --prompt "Describe what you see and hear" --image /path/to/image.jpg --audio /path/to/audio.wav
```

#### Thinking Budget

For thinking models (e.g., Qwen3.5), you can limit the number of tokens spent in the thinking block:

```sh
mlx_vlm.generate --model mlx-community/Qwen3.5-2B-4bit \
  --thinking-budget 50 \
  --thinking-start-token "<think>" \
  --thinking-end-token "</think>" \
  --enable-thinking \
  --prompt "Solve 2+2"
```

| Flag | Description |
|------|-------------|
| `--enable-thinking` | Activate thinking mode in the chat template |
| `--thinking-budget` | Max tokens allowed inside the thinking block |
| `--thinking-start-token` | Token that opens a thinking block (default: `<think>`) |
| `--thinking-end-token` | Token that closes a thinking block (default: `</think>`) |

When the budget is exceeded, the model is forced to emit `\n</think>` and transition to the answer. If `--enable-thinking` is passed but the model's chat template does not support it, the budget is applied only if the model generates the start token on its own.

On the server, thinking mode is disabled by default.
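To make the budget rule concrete, here is a minimal sketch of the policy described for `--thinking-budget`, written as a replay over an already-generated token list. It is not MLX-VLM's implementation: `enforce_thinking_budget` is a hypothetical helper, and because a recorded stream cannot react to the injected close token, it simply drops the remaining recorded thinking tokens, whereas a real decoder would keep sampling conditioned on `\n</think>`.

```python
from typing import Iterable, List


def enforce_thinking_budget(
    tokens: Iterable[str],
    budget: int,
    start_token: str = "<think>",
    end_token: str = "</think>",
) -> List[str]:
    """Replay a recorded token stream, closing the thinking block after `budget` tokens.

    Hypothetical sketch; not the MLX-VLM code path.
    """
    out: List[str] = []
    in_thinking = False  # inside start_token ... end_token
    skipping = False     # dropping recorded thinking tokens after a forced close
    spent = 0            # tokens emitted inside the current thinking block

    for tok in tokens:
        if skipping:
            # Simplification: a recorded stream cannot react to the forced close,
            # so discard its remaining thinking tokens up to its own end token.
            if tok == end_token:
                skipping = False
            continue
        if tok == start_token and not in_thinking:
            # The budget only applies once a start token actually appears.
            in_thinking, spent = True, 0
            out.append(tok)
        elif tok == end_token and in_thinking:
            in_thinking = False  # the model closed the block within budget
            out.append(tok)
        elif in_thinking:
            if spent >= budget:
                # Budget exhausted: force "\n" + end token, then move to the answer.
                out.append("\n" + end_token)
                in_thinking, skipping = False, True
            else:
                out.append(tok)
                spent += 1
        else:
            out.append(tok)  # answer tokens pass through unchanged
    return out


# Example: a budget of 2 keeps two thinking tokens, then forces the close.
print(enforce_thinking_budget(["<think>", "a", "b", "c", "</think>", "4"], budget=2))
```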
Start the server with `--enable-thinking` to make thinking mode the default for requests that do not specify it:

```sh
mlx_vlm.server --model Qwen/Qwen3.5-4B --enable-thinking
```

You can also set server defaults for the thinking budget and delimiter tokens:

```sh
mlx_vlm.server --model Qwen/Qwen3.5-4B \
  --enable-thinking \
  --thinking-budget 512 \
  --thinking-start-token "<think>" \
  --thinking-end-token "</think>"
```

Requests can override the server defaults with `enable_thinking`, `thinking_budget`, `thinking_start_token`, or `thinking_end_token`.

### Speculative Decoding

Speed up generation by drafting several candidate tokens with a small "drafter" model and verifying them in a single target forward pass. Three drafter families are supported.

| Flag | Description |
|------|-------------|
| `--draft-model` | Hugging Face repo or local path for the drafter |
| `--draft-kind` | Drafter family: `dflash` (default), `eagle3`, or `mtp` (Gemma 4) |
| `--draft-block-size` | Override the drafter's configured block size |

See [docs/usage.md](docs/usage.md) for Python API examples including batch generation.

#### DFlash (Qwen3.5)

A lightweight block-diffusion drafter that predicts multiple tokens per round, typically giving a 2–3× speedup.

```sh
# Text generation with speculative decoding
mlx_vlm.generate --model Qwen/Qwen3.5-4B \
  --draft-model z-lab/Qwen3.5-4B-DFlash \
  --prompt "Write a quicksort in Python." \
  --max-tokens 512 --temperature 0 --enable-thinking

# Also works with images
mlx_vlm.generate --model Qwen/Qwen3.5-4B \
  --draft-model z-lab/Qwen3.5-4B-DFlash \
  --image examples/images/cats.jpg \
  --prompt "Describe this image." \
  --max-tokens 256 --temperature 0 --enable-thinking

# Server with speculative decoding
mlx_vlm.server --model Qwen/Qwen3.5-4B \
  --draft-model z-lab/Qwen3.5-4B-DFlash
```

DFlash draft-cache windowing is available from the Python API.
During speculative decoding, the target model still verifies every proposed token with its full KV cache; this knob only changes the DFlash drafter cache. When `draft_window_size` is set, the drafter keeps at most that many recent committed tokens in its own KV cache instead of attending over the full generated prefix. That reduces draft-side cache length and memory, but it can lower acceptance because the drafter has less context than the target verifier. On MLX, the full draft cache is usually faster for Qwen3.5 DFlash, so windowing defaults to `None`; set it only when you want to experiment with this compact recent-token cache tradeoff:

```python
from mlx_vlm import load
from mlx_vlm.generate import generate
from mlx_vlm.speculative.drafters import load_drafter

model, processor = load("Qwen/Qwen3.5-4B")
draft_model, draft_kind = load_drafter("z-lab/Qwen3.5-4B-DFlash")
draft_model.config.draft_window_size = 256  # None disables windowing

result = generate(
    model,
    processor,
    "Write a quicksort in Python.",
    max_tokens=512,
    temperature=0,
    draft_model=draft_model,
    draft_kind=draft_kind,
)
```

#### Gemma 4 MTP

[Multi-Token Prediction](https://ai.google.dev/gemma/docs/mtp/mtp): Google's 4-layer "assistant" drafter that shares K/V with the target and drafts multiple tokens autoregressively from a constant position. Pass `--draft-kind mtp` to dispatch the MTP round-loop.

```sh
mlx_vlm.generate --model mlx-community/gemma-4-31B-it-bf16 \
  --draft-model mlx-community/gemma-4-31B-it-assistant-bf16 \
  --draft-kind mtp --draft-block-size 4 \
  --prompt "Explain speculative decoding in 3 sentences." \
  --max-tokens 256 --temperature 0

# Server
mlx_vlm.server --model mlx-community/gemma-4-31B-it-bf16 \
  --draft-model mlx-community/gemma-4-31B-it-assistant-bf16 \
  --draft-kind mtp --draft-block-size 4
```

Supported pairings (target ↔ drafter):

| Target | Drafter |
|--------|---------|
| `mlx-community/gemma-4-E2B-it-bf16` | `mlx-community/gemma-4-E2B-it-assistant-bf16` |
| `mlx-community/gemma-4-E4B-it-bf16` | `mlx-community/gemma-4-E4B-it-assistant-bf16` |
| `mlx-community/gemma-4-26B-A4B-it-bf16` | `mlx-community/gemma-4-26B-A4B-it-assistant-bf16` |
| `mlx-community/gemma-4-31B-it-bf16` | `mlx-community/gemma-4-31B-it-assistant-bf16` |

Measured speedups (greedy, byte-identical output): up to **3.94×** on 26B-A4B and **2.29×** on 31B at B=4. See [`mlx_vlm/speculative/drafters/gemma4_assistant/README.md`](mlx_vlm/speculative/drafters/gemma4_assistant/README.md) for full sweeps and architecture notes.

#### Gemma 4 EAGLE-3

[EAGLE-3](https://sgl-project.github.io/SpecForge/concepts/EAGLE3.html) drafts from three target hidden-state captures with a lightweight one-layer speculator. The Red Hat Speculators checkpoint auto-detects as `--draft-kind eagle3`.

```sh
mlx_vlm.generate --model mlx-community/gemma-4-31B-it-bf16 \
  --draft-model RedHatAI/gemma-4-31B-it-speculator.eagle3 \
  --prompt "Explain speculative decoding in 3 sentences." \
  --max-tokens 256 --temperature 0

# Server
mlx_vlm.server --model mlx-community/gemma-4-31B-it-bf16 \
  --draft-model RedHatAI/gemma-4-31B-it-speculator.eagle3
```

### Chat UI with Gradio

Launch a chat interface using Gradio:

```sh
mlx_vlm.chat_ui --model mlx-community/Qwen2-VL-2B-Instruct-4bit
```

### Python Script

Here's an example of how to use MLX-VLM in a Python script:

```python
import mlx.core as mx
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

# Load the model
model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"
model, processor = load(model_path)
config = load_config(model_path)

# Prepare input
image = ["http://images.cocodataset.org/val2017/000000039769.jpg"]
# image = [Image.open("...")] can also be used with PIL.Image.Image objects
prompt = "Describe this image."

# Apply chat template
formatted_prompt = apply_chat_template(
    processor, config, prompt, num_images=len(image)
)

# Generate output
output = generate(model, processor, formatted_prompt, image, verbose=False)
print(output)
```

#### Audio Example

```python
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

# Load model with audio support
model_path = "mlx-community/gemma-3n-E2B-it-4bit"
model, processor = load(model_path)
config = model.config

# Prepare audio input
audio = ["/path/to/audio1.wav", "/path/to/audio2.mp3"]
prompt = "Describe what you hear in these audio files."

# Apply chat template with audio
formatted_prompt = apply_chat_template(
    processor, config, prompt, num_audios=len(audio)
)

# Generate output with audio
output = generate(model, processor, formatted_prompt, audio=audio, verbose=False)
print(output)
```

#### Multi-Modal Example (Image + Audio)

```python
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

# Load multi-modal model
model_path = "mlx-community/gemma-3n-E2B-it-4bit"
model, processor = load(model_path)
config = model.config

# Prepare inputs
image = ["/path/to/image.jpg"]
audio = ["/path/to/audio.wav"]
prompt = ""

# Apply chat template
formatted_prompt = apply_chat_template(
    processor, config, prompt,
    num_images=len(image),
    num_audios=len(audio)
)

# Generate output
output = generate(model, processor, formatted_prompt, image, audio=audio, verbose=False)
print(output)
```

### Server (FastAPI)

Start the server:

```sh
mlx_vlm.server --port 8080

# Preload a model at startup (Hugging Face repo or local path)
mlx_vlm.server --model <hf_repo_or_local_path>

# Preload a model with adapter
mlx_vlm.server --model <hf_repo_or_local_path> --adapter-path <adapter_path>

# With trust remote code enabled (required for some models)
mlx_vlm.server --trust-remote-code

# Enable thinking mode by default for requests that do not override it
mlx_vlm.server --model Qwen/Qwen3.5-4B --enable-thinking

# Configure thinking defaults at startup
mlx_vlm.server --model Qwen/Qwen3.5-4B \
  --enable-thinking \
  --thinking-budget 512 \
  --thinking-start-token "<think>" \
  --thinking-end-token "</think>"
```

#### Server Options

- `--model`: Preload a model at server startup; accepts a Hugging Face repo ID or local path (optional; if omitted, the model loads lazily on the first request)
- `--adapter-path`: Path to adapter weights to use with the preloaded model
- `--draft-model`: Speculative drafter path or HF ID (e.g. `z-lab/Qwen3.5-4B-DFlash`, `RedHatAI/gemma-4-31B-it-speculator.eagle3`, `google/gemma-4-31B-it-assistant`); enables speculative decoding for roughly 2× or higher throughput
- `--draft-kind`: Drafter family: `dflash` (default), `eagle3`, or `mtp` (Gemma 4)
- `--draft-block-size`: Override the drafter's configured block size
- `--host`: Host address (default: `0.0.0.0`)
- `--port`: Port number (default: `8080`)
- `--trust-remote-code`: Trust remote code when loading models from the Hugging Face Hub
- `--enable-thinking`: Enable thinking mode by default for requests that do not set `enable_thinking`
- `--thinking-budget`: Default maximum number of tokens allowed inside a thinking block
- `--thinking-start-token`: Default token that opens a thinking block
- `--thinking-end-token`: Default token that closes a thinking block (`--thinking-eos-token` is also accepted)
- `--kv-bits`: Number of bits for KV cache quantization (e.g. `8` for uniform, `3.5` for TurboQuant)
- `--kv-quant-scheme`: KV cache quantization backend (`uniform` or `turboquant`)
- `--kv-group-size`: Group size for uniform KV cache quantization (default: `64`)
- `--max-kv-size`: Maximum KV cache size in tokens
- `--vision-cache-size`: Maximum number of cached vision features (default: `20`)
- `--log-level`: Logging level, one of `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL` (default: `INFO`)

You can also enable trust remote code via an environment variable:

```sh
MLX_TRUST_REMOTE_CODE=true mlx_vlm.server
```

The server provides multiple endpoints for different use cases and supports dynamic model loading/unloading with caching (one model at a time).

### Continuous Batching

The server supports continuous batching for higher throughput when handling multiple concurrent requests. New requests join the active batch immediately without waiting for existing requests to finish, and mixed batches of image and text-only requests are supported. Continuous batching is enabled automatically when the server loads a model.
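Seen from the client side, continuous batching only pays off when several requests are in flight at once. The sketch below is a minimal standard-library client, assuming a server on `localhost:8080` and the OpenAI-style `/v1/chat/completions` body used elsewhere in this README; `make_payload`, `post_json`, and `send_concurrently` are hypothetical helper names, not part of MLX-VLM.

```python
import json
import urllib.request
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, List

# Assumption: the server was started with `mlx_vlm.server --port 8080 --model ...`.
URL = "http://localhost:8080/v1/chat/completions"


def make_payload(model: str, prompt: str, max_tokens: int = 64) -> Dict:
    """Build a text-only, non-streaming chat completion body."""
    return {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }


def post_json(url: str, payload: Dict, timeout: float = 300.0) -> Dict:
    """POST a JSON body and decode the JSON response (stdlib only)."""
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


def send_concurrently(
    model: str,
    prompts: List[str],
    send: Callable[[str, Dict], Dict] = post_json,
    url: str = URL,
) -> List[Dict]:
    """Submit every prompt at once so the server can batch them; results keep prompt order."""
    with ThreadPoolExecutor(max_workers=max(1, len(prompts))) as pool:
        futures = [pool.submit(send, url, make_payload(model, p)) for p in prompts]
        return [f.result() for f in futures]


# Example (needs a running server):
# results = send_concurrently(
#     "mlx-community/Qwen2.5-VL-3B-Instruct-4bit",
#     ["Describe MLX in one sentence.", "Name three fruits."],
# )
```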
You can pre-load a model at startup so it's ready to serve immediately:

```sh
mlx_vlm.server --port 8080 --model mlx-community/Qwen2.5-VL-3B-Instruct-4bit
```

Verify via the health endpoint:

```sh
curl http://localhost:8080/health
# {"status":"healthy","loaded_model":"...","apc_enabled":false}
```

If `--model` is omitted, the model is loaded on the first request.

### Automatic Prefix Caching (APC)

Automatic Prefix Caching reuses block-level K/V cache state across requests that share the same prefix. It is useful for repeated long documents, long chat histories, or retrieval contexts where each request appends a short new suffix.

APC has two tiers:

- **Warm memory**: keeps reusable `APCBlock` tensors in process memory. This is the fastest path, but it keeps both the reusable block pool and the runtime `KVCache`.
- **Warm disk**: persists cached prefixes as safetensors shards so they survive process restarts. Warm-disk reads build the layer-major prompt cache directly without promoting restored blocks into the `APCBlock` pool; writes can still populate both memory and disk tiers.

#### Python Script

Use `APCManager` directly when calling `stream_generate`:

```python
from pathlib import Path

from mlx_vlm import load, stream_generate
from mlx_vlm.apc import APCManager, DiskBlockStore
from mlx_vlm.prompt_utils import apply_chat_template

model_id = "Qwen/Qwen3-VL-4B-Instruct"
model, processor = load(model_id)

disk = DiskBlockStore(
    Path("~/.cache/mlx-vlm/caching").expanduser(),
    namespace=model_id,
    max_bytes=3 * (1 << 30),  # 3 GB disk cap; use None for uncapped
)
apc = APCManager(num_blocks=4096, block_size=16, disk=disk)

document = Path("long_document.txt").read_text()

try:
    # First request computes the full prefix and stores reusable K/V blocks.
    prompt1 = apply_chat_template(
        processor,
        model.config,
        prompt=f"{document}\n\nSummarize the key decisions.",
        num_images=0,
    )
    for _ in stream_generate(
        model, processor, prompt1, max_tokens=128, temperature=0.0, apc_manager=apc
    ):
        pass

    # Second request shares the same document prefix and only prefills the suffix.
    prompt2 = apply_chat_template(
        processor,
        model.config,
        prompt=f"{document}\n\nList the open engineering risks.",
        num_images=0,
    )
    for chunk in stream_generate(
        model, processor, prompt2, max_tokens=128, temperature=0.0, apc_manager=apc
    ):
        print(chunk.text, end="", flush=True)

    print(apc.stats_snapshot())
finally:
    apc.close()
```

To compare cold, warm-memory, warm-disk, and disk-eviction behavior with a model, use the same direct API path:

```python
import os
import tempfile
import time
from pathlib import Path

from mlx_vlm import load, stream_generate
from mlx_vlm.apc import APCManager, DiskBlockStore
from mlx_vlm.prompt_utils import apply_chat_template

model_id = "Qwen/Qwen3-VL-4B-Instruct"
contexts = [8000, 20000, 50000, 100000]
disk_cap_gb = 0  # 0 means uncapped
shard_max_blocks = 256
context_sweep_max_tokens = 1  # one token is enough to measure prefill reuse
test_prompt_tokens = 8000
fill_prompts = 80
eviction_disk_cap_gb = 3.0

os.environ["APC_DISK_SHARD_MAX_BLOCKS"] = str(shard_max_blocks)

model, processor = load(model_id)
tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor


def disk_cap_bytes(gb: float):
    return None if gb <= 0 else int(gb * (1 << 30))


def make_context(target_tokens: int, seed: int = 0) -> str:
    line = (
        f"Document {seed}: APC benchmark content with deterministic facts, "
        "dates, identifiers, and repeated technical notes.\n"
    )
    line_tokens = max(1, len(tokenizer.encode(line, add_special_tokens=False)))
    text = line * max(1, target_tokens // line_tokens)
    while len(tokenizer.encode(text, add_special_tokens=False)) < target_tokens:
        text += line
    return text


def make_prompt(context: str, question: str) -> str:
    return apply_chat_template(
        processor,
        model.config,
        prompt=f"{context}\n\n{question}",
        num_images=0,
    )


def run_once(apc: APCManager, context: str, question: str, max_tokens: int = 32):
    prompt = make_prompt(context, question)
    apc.reset_stats()
    last = None
    output = []
    start = time.perf_counter()
    for chunk in stream_generate(
        model,
        processor,
        prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        apc_manager=apc,
    ):
        output.append(chunk.text)
        last = chunk
    if last is None:
        raise RuntimeError("generation returned no chunks")
    return {
        "wall_s": time.perf_counter() - start,
        "prompt_tokens": last.prompt_tokens,
        "prompt_tps": last.prompt_tps,
        "generation_tps": last.generation_tps,
        "apc": apc.stats_snapshot(),
        "text": "".join(output).strip(),
    }


def print_result(label: str, result: dict) -> None:
    stats = result["apc"]
    print(
        f"{label:<12} "
        f"prompt_tokens={result['prompt_tokens']:>7} "
        f"prompt_tps={result['prompt_tps']:>8.1f} "
        f"gen_tps={result['generation_tps']:>7.1f} "
        f"matched={stats.get('matched_tokens', 0):>7} "
        f"disk_hits={stats.get('disk_hits', 0):>5} "
        f"disk_evictions={stats.get('disk_evictions', 0):>5}"
    )


def open_apc(cache_root: Path, namespace: str, disk_gb: float) -> APCManager:
    disk = DiskBlockStore(
        cache_root,
        namespace=namespace,
        max_bytes=disk_cap_bytes(disk_gb),
    )
    return APCManager(num_blocks=4096, block_size=16, disk=disk)


def run_context_sweep() -> None:
    print("cold / warm-memory / warm-disk")
    with tempfile.TemporaryDirectory() as tmp:
        cache_root = Path(tmp)
        for target_tokens in contexts:
            context = make_context(target_tokens)
            namespace = f"{model_id}-context-{target_tokens}"

            apc = open_apc(cache_root, namespace, disk_cap_gb)
            try:
                print(f"\ncontext ~= {target_tokens} text tokens")
                print_result(
                    "cold",
                    run_once(
                        apc,
                        context,
                        "Summarize the key decisions.",
                        max_tokens=context_sweep_max_tokens,
                    ),
                )
                print_result(
                    "warm-memory",
                    run_once(
                        apc,
                        context,
                        "List the open engineering risks.",
                        max_tokens=context_sweep_max_tokens,
                    ),
                )
            finally:
                # Closing waits for queued disk writes before reopening the disk tier.
                apc.close()

            apc = open_apc(cache_root, namespace, disk_cap_gb)
            try:
                print_result(
                    "warm-disk",
                    run_once(
                        apc,
                        context,
                        "Extract the implementation timeline.",
                        max_tokens=context_sweep_max_tokens,
                    ),
                )
            finally:
                apc.close()


def run_disk_eviction_workload() -> None:
    print("\ndisk eviction workload")
    with tempfile.TemporaryDirectory() as tmp:
        cache_root = Path(tmp)
        namespace = f"{model_id}-eviction"
        test_context = make_context(test_prompt_tokens, seed=0)

        apc = open_apc(cache_root, namespace, eviction_disk_cap_gb)
        try:
            print_result(
                "seed",
                run_once(apc, test_context, "Summarize the retained test prefix."),
            )
        finally:
            apc.close()

        apc = open_apc(cache_root, namespace, eviction_disk_cap_gb)
        try:
            for i in range(fill_prompts):
                fill_context = make_context(test_prompt_tokens, seed=i + 1)
                run_once(
                    apc,
                    fill_context,
                    f"Summarize filler document {i + 1}.",
                    max_tokens=1,
                )
                if (i + 1) % 10 == 0:
                    stats = apc.stats_snapshot()
                    print(
                        f"filled={i + 1:>3} "
                        f"disk_gb={stats.get('disk_bytes', 0) / (1 << 30):.2f} "
                        f"disk_evictions={stats.get('disk_evictions', 0)}"
                    )
        finally:
            apc.close()

        apc = open_apc(cache_root, namespace, eviction_disk_cap_gb)
        try:
            print_result(
                "post-fill",
                run_once(
                    apc,
                    test_context,
                    "Check whether the retained test prefix still restores.",
                ),
            )
        finally:
            apc.close()


run_context_sweep()
run_disk_eviction_workload()
```

#### Server

Enable in-memory APC for the server with environment variables:

```sh
APC_ENABLED=1 \
APC_NUM_BLOCKS=4096 \
mlx_vlm.server --model Qwen/Qwen3-VL-4B-Instruct --port 8080
```

Enable the persistent disk tier:

```sh
APC_ENABLED=1 \
APC_NUM_BLOCKS=4096 \
APC_DISK_PATH=~/.cache/mlx-vlm/caching \
APC_DISK_MAX_GB=3 \
APC_DISK_SHARD_MAX_BLOCKS=256 \
mlx_vlm.server --model Qwen/Qwen3-VL-4B-Instruct --port 8080
```

Repeated requests with the same long prefix will hit APC automatically:

```sh
curl -X POST "http://localhost:8080/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -H "X-APC-Tenant: demo" \
  -d '{
    "model": "Qwen/Qwen3-VL-4B-Instruct",
    "messages": [{
      "role": "user",
      "content": "Paste a long shared document here.\n\nNow answer question A."
    }],
    "max_tokens": 128
  }'
```

Use the same `X-APC-Tenant` value for requests that may share cached prefixes. Use different tenant values to isolate cache entries between users or workspaces.

Inspect and reset APC state:

```sh
curl http://localhost:8080/v1/cache/stats
curl -X POST http://localhost:8080/v1/cache/reset
```

Common APC environment variables:

| Variable | Default | Description |
|----------|---------|-------------|
| `APC_ENABLED` | `0` | Set to `1` to enable APC |
| `APC_NUM_BLOCKS` | `2048` | Number of in-memory APC blocks |
| `APC_BLOCK_SIZE` | `16` | Tokens per APC block |
| `APC_DISK_PATH` | unset | Directory for persistent disk shards |
| `APC_DISK_MAX_GB` | `0` | Disk cap in GB; `0` means uncapped |
| `APC_DISK_SHARD_MAX_BLOCKS` | `256` | Max blocks per disk segment shard |
| `APC_MAX_POOL_TENSORS` | `450000` | Stops adding memory blocks before the Metal resource limit; disk writes continue |
| `APC_LAYER_MAJOR_MEMORY_MIN_TOKENS` | `50000` | Store long warm-memory prefixes as compact layer-major snapshots instead of per-block tensors |
| `APC_HASH` | `fast` | Set to `sha256` for a stable cryptographic hash |

APC is disabled automatically for models that use a custom cache layout. On the server, APC is also skipped when KV-cache quantization is enabled.

#### KV Cache Quantization

Reduce KV cache memory during continuous batching with `--kv-bits`. Both uniform quantization and TurboQuant are supported:

```sh
# Uniform 8-bit KV cache quantization
mlx_vlm.server --model google/gemma-4-26b-a4b-it --kv-bits 8

# TurboQuant 3.5-bit (3-bit keys + 4-bit values)
mlx_vlm.server --model google/gemma-4-26b-a4b-it --kv-bits 3.5 --kv-quant-scheme turboquant
```

Full-attention layers use quantized batch caches while sliding-window layers keep their fixed-size rotating caches.
The last full-attention layer stays unquantized, since it is sensitive to quantization error in deep models. Tested with gemma-4-26b-a4b-it at 20K context:

| Config | Gen tok/s | KV Cache | KV Reduction |
|--------|-----------|----------|--------------|
| No quant | 50.3 | 0.624 GB | 1x |
| Uniform 8-bit | 52.6 | 0.469 GB | **1.33x** |
| TurboQuant 3.5-bit | 25.6 | 0.365 GB | **1.71x** |

> Models with all full-attention layers (e.g. Qwen, LLaMA) see larger reductions: up to 3.6x at 8-bit and 6.4x at 4-bit.

#### Log Probabilities

The `/chat/completions` endpoint supports OpenAI-compatible per-token log probabilities. Pass `logprobs: true` (and optionally `top_logprobs: N`, up to 20) in the request:

```sh
curl -X POST "http://localhost:8080/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "mlx-community/Qwen2-VL-2B-Instruct-4bit",
    "messages": [{"role":"user","content":"Say hi in 3 words."}],
    "max_tokens": 8,
    "logprobs": true,
    "top_logprobs": 3
  }'
```

Each choice gets a `logprobs.content[]` list with one entry per generated token: `{token, logprob, bytes, top_logprobs: [{token, logprob, bytes}, ...]}`. This works for both streaming and non-streaming requests.

`top_logprobs` requires the server to be started with a non-zero cap on how many alternatives it will compute per token (default `0` = disabled, max `20`). Set it via the `--top-logprobs-k` flag or the `TOP_LOGPROBS_K` environment variable:

```sh
mlx_vlm.server --model mlx-community/Qwen2-VL-2B-Instruct-4bit --top-logprobs-k 5
# or
TOP_LOGPROBS_K=5 mlx_vlm.server --model mlx-community/Qwen2-VL-2B-Instruct-4bit
```

Per-request `top_logprobs` is clamped to `TOP_LOGPROBS_K`. When `TOP_LOGPROBS_K=0`, requests with `logprobs: true` still return chosen-token logprobs; only the `top_logprobs` list stays empty. Leaving the cap at `0` keeps the vocabulary-wide sort out of the decode graph, so deployments that don't need `top_logprobs` pay no overhead for it.
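Because each `logprobs.content[]` entry carries the chosen token's log probability, a client can derive per-token probabilities or a sequence perplexity without any extra server work. A minimal sketch, assuming the response JSON has already been parsed into a dict with the entry shape described above and that `logprob` is a natural logarithm (as in OpenAI's format); `summarize_logprobs` is a hypothetical helper.

```python
import math
from typing import Any, Dict, List


def summarize_logprobs(choice: Dict[str, Any]) -> Dict[str, Any]:
    """Summarize one choice's `logprobs.content[]` list from a parsed JSON response."""
    entries: List[Dict[str, Any]] = choice["logprobs"]["content"]
    # Assumption: `logprob` values are natural logs, as in OpenAI's format.
    probs = [math.exp(e["logprob"]) for e in entries]
    mean_logprob = sum(e["logprob"] for e in entries) / len(entries) if entries else 0.0
    return {
        "tokens": [e["token"] for e in entries],
        "probs": probs,
        # Perplexity of the generated sequence: exp(-mean log probability).
        "perplexity": math.exp(-mean_logprob),
        # Inner lists stay empty when the server's top-logprobs cap is 0.
        "alternatives": [
            [alt["token"] for alt in e.get("top_logprobs", [])] for e in entries
        ],
    }


# Usage: summarize_logprobs(response_json["choices"][0])
```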
#### Structured Outputs

The `/v1/chat/completions` and `/v1/responses` endpoints support OpenAI-compatible `json_schema` structured outputs. The server constrains generation to the supplied JSON schema and supports both streaming and non-streaming responses.

You can define the schema with Pydantic:

```python
from typing import Literal

from pydantic import BaseModel, ConfigDict, Field


class AnimalResult(BaseModel):
    model_config = ConfigDict(extra="forbid")

    animal: Literal["dog", "cat", "bird", "unknown"]
    species: str = Field(max_length=60)
    description: str = Field(max_length=200)


schema = AnimalResult.model_json_schema()
```

Call the local server with the OpenAI Python client:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="mlx-community/Qwen3.5-4B-MLX-4bit",
    messages=[
        {"role": "user", "content": "Return a dog object."},
    ],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "AnimalResult",
            "strict": True,
            "schema": schema,
        },
    },
)

result = AnimalResult.model_validate_json(response.choices[0].message.content)
print(result)
```

Example output:

```text
animal='dog' species='Canis lupus familiaris' description='A domesticated canine known for companionship and loyalty.'
```

Chat completions use top-level `response_format`.
The same format works for text-only and multimodal requests:

```sh
curl -X POST "http://localhost:8080/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "mlx-community/Qwen3.5-4B-MLX-4bit",
    "messages": [{
      "role": "user",
      "content": [
        {"type": "text", "text": "Identify the main animal in this image."},
        {"type": "image_url", "image_url": {"url": "/path/to/image.jpg"}}
      ]
    }],
    "response_format": {
      "type": "json_schema",
      "json_schema": {
        "name": "AnimalResult",
        "strict": true,
        "schema": {
          "type": "object",
          "properties": {
            "animal": {"type": "string", "enum": ["dog", "cat", "bird", "unknown"]},
            "species": {"type": "string", "maxLength": 60},
            "description": {"type": "string", "maxLength": 200}
          },
          "required": ["animal", "species", "description"],
          "additionalProperties": false
        }
      }
    },
    "max_tokens": 256
  }'
```

Structured outputs are also supported with:

- Streaming chat completions, by setting `"stream": true`
- The responses API, via `text.format` on `/v1/responses`
- Text-only requests, using the same `response_format` shape

Structured outputs are not currently supported with speculative decoding.
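The two endpoints nest the schema differently: chat completions take it under `response_format.json_schema`, while `/v1/responses` takes it through `text.format`. A minimal sketch of both payload fragments, assuming the server mirrors OpenAI's request shapes (OpenAI's Responses API places `name`, `strict`, and `schema` directly inside `text.format`); both helper names are hypothetical.

```python
from typing import Any, Dict


def chat_response_format(name: str, schema: Dict[str, Any], strict: bool = True) -> Dict[str, Any]:
    """`response_format` value for /v1/chat/completions (schema nested under `json_schema`)."""
    return {
        "type": "json_schema",
        "json_schema": {"name": name, "strict": strict, "schema": schema},
    }


def responses_text_format(name: str, schema: Dict[str, Any], strict: bool = True) -> Dict[str, Any]:
    """`text` value for /v1/responses, assuming OpenAI's flattened `text.format` shape."""
    return {
        "format": {"type": "json_schema", "name": name, "strict": strict, "schema": schema}
    }


animal_schema = {
    "type": "object",
    "properties": {"animal": {"type": "string", "enum": ["dog", "cat", "bird", "unknown"]}},
    "required": ["animal"],
    "additionalProperties": False,
}
# Chat body:      {..., "response_format": chat_response_format("AnimalResult", animal_schema)}
# Responses body: {..., "text": responses_text_format("AnimalResult", animal_schema)}
```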
#### How Continuous Batching Works

- A dedicated generation thread runs a `BatchGenerator` that processes multiple requests in parallel
- Image requests are prefilled individually with their own vision embeddings, then join the shared decoding batch
- Text-only requests are batched together for efficient prefill
- After prefill, all requests decode together in a single batch, sharing GPU compute

#### Available Endpoints

- `/models` and `/v1/models` - List models available locally
- `/chat/completions` and `/v1/chat/completions` - OpenAI-compatible chat-style interaction endpoint with support for images, audio, and text
- `/responses` and `/v1/responses` - OpenAI-compatible responses endpoint
- `/health` - Check server status
- `/metrics` and `/v1/metrics` - Inspect rolling request metrics, throughput, and runtime counters
- `/unload` - Unload current model from memory

#### Usage Examples

##### List available models

```sh
curl "http://localhost:8080/models"
```

##### Text Input

```sh
curl -X POST "http://localhost:8080/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "mlx-community/Qwen2-VL-2B-Instruct-4bit",
    "messages": [
      {
        "role": "user",
        "content": "Hello, how are you"
      }
    ],
    "stream": true,
    "max_tokens": 100
  }'
```

##### Image Input

```sh
curl -X POST "http://localhost:8080/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "mlx-community/Qwen2.5-VL-32B-Instruct-8bit",
    "messages": [
      {
        "role": "system",
        "content": "You are a helpful assistant."
      },
      {
        "role": "user",
        "content": [
          {
            "type": "text",
            "text": "This is today'\''s chart for energy demand in California. Can you provide an analysis of the chart and comment on the implications for renewable energy in California?"
          },
          {
            "type": "input_image",
            "image_url": "/path/to/repo/examples/images/renewables_california.png"
          }
        ]
      }
    ],
    "stream": true,
    "max_tokens": 1000
  }'
```

##### Audio Support (New)

```sh
curl -X POST "http://localhost:8080/generate" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "mlx-community/gemma-3n-E2B-it-4bit",
    "messages": [
      {
        "role": "user",
        "content": [
          { "type": "text", "text": "Describe what you hear in these audio files" },
          { "type": "input_audio", "input_audio": "/path/to/audio1.wav" },
          { "type": "input_audio", "input_audio": "https://example.com/audio2.mp3" }
        ]
      }
    ],
    "stream": true,
    "max_tokens": 500
  }'
```

##### Multi-Modal (Image + Audio)

```sh
curl -X POST "http://localhost:8080/generate" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "mlx-community/gemma-3n-E2B-it-4bit",
    "messages": [
      {
        "role": "user",
        "content": [
          {"type": "input_image", "image_url": "/path/to/image.jpg"},
          {"type": "input_audio", "input_audio": "/path/to/audio.wav"}
        ]
      }
    ],
    "max_tokens": 100
  }'
```

##### Responses Endpoint

```sh
curl -X POST "http://localhost:8080/responses" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "mlx-community/Qwen2-VL-2B-Instruct-4bit",
    "messages": [
      {
        "role": "user",
        "content": [
          {"type": "input_text", "text": "What is in this image?"},
          {"type": "input_image", "image_url": "/path/to/image.jpg"}
        ]
      }
    ],
    "max_tokens": 100
  }'
```

#### Request Parameters

- `model`: Model identifier (required)
- `messages`: Chat messages for chat/OpenAI endpoints
- `max_tokens`: Maximum tokens to generate
- `temperature`: Sampling temperature
- `top_p`: Top-p sampling parameter
- `top_k`: Top-k sampling cutoff
- `min_p`: Min-p sampling threshold
- `repetition_penalty`: Penalty applied to repeated tokens
- `enable_thinking`: Override the server thinking-mode default for a request (`true` or `false`)
- `thinking_budget`: Maximum tokens allowed inside the thinking block
- `thinking_start_token`: Token that opens a thinking block
- `thinking_end_token`: Token that closes a thinking block
- `stream`: Enable streaming responses

## Activation Quantization (CUDA)

When running on NVIDIA GPUs with MLX CUDA, models quantized with the `mxfp8` or `nvfp4` modes require activation quantization to work properly. This converts `QuantizedLinear` layers to `QQLinear` layers, which quantize both weights and activations.

### Command Line

Use the `-qa` or `--quantize-activations` flag:

```sh
mlx_vlm.generate --model /path/to/mxfp8-model --prompt "Describe this image" --image /path/to/image.jpg -qa
```

### Python API

Pass `quantize_activations=True` to the `load` function:

```python
from mlx_vlm import load, generate

# Load with activation quantization enabled
model, processor = load(
    "path/to/mxfp8-quantized-model",
    quantize_activations=True
)

# Generate as usual
output = generate(model, processor, "Describe this image", image=["image.jpg"])
```

### Supported Quantization Modes

- `mxfp8` - 8-bit MX floating point
- `nvfp4` - 4-bit NVIDIA floating point

> **Note**: This feature is required for `mxfp8`/`nvfp4` quantized models on CUDA. On Apple Silicon (Metal), these models work without the flag.

## Multi-Image Chat Support

MLX-VLM supports analyzing multiple images simultaneously with select models. This enables more complex visual reasoning and comparative analysis across several images in a single conversation.

### Usage Examples

#### Python Script

```python
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"
model, processor = load(model_path)
config = model.config

images = ["path/to/image1.jpg", "path/to/image2.jpg"]
prompt = "Compare these two images."
formatted_prompt = apply_chat_template( processor, config, prompt, num_images=len(images) ) output = generate(model, processor, formatted_prompt, images, verbose=False) print(output) ``` #### Command Line ```sh mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --prompt "Compare these images" --image path/to/image1.jpg path/to/image2.jpg ``` ## Video Understanding MLX-VLM also supports video analysis such as captioning, summarization, and more, with select models. ### Supported Models The following models support video chat: 1. Qwen2-VL 2. Qwen2.5-VL 3. Idefics3 4. LLaVA With more coming soon. ### Usage Examples #### Command Line ```sh mlx_vlm.video_generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --prompt "Describe this video" --video path/to/video.mp4 --max-pixels 224 224 --fps 1.0 ``` These examples demonstrate how to use multiple images with MLX-VLM for more complex visual reasoning tasks. ## Vision Feature Caching In multi-turn conversations about an image, the vision encoder runs on every turn even though the image hasn't changed. `VisionFeatureCache` stores projected vision features in an LRU cache keyed by image path, so the expensive vision encoder is only called once per unique image. ### How It Works 1. **First turn (cache miss)** -- `encode_image()` runs the full vision pipeline (vision tower + projector), stores the result in the cache, and passes it to the language model. 2. **Subsequent turns (cache hit)** -- the cached features are passed directly via `cached_image_features`, skipping the vision encoder entirely. 3. **Image switch** -- when the image changes, it's a new cache key so features are computed and cached. Switching back to a previous image is a cache hit. The cache holds up to 8 entries (configurable) and uses LRU eviction. 
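The three steps above amount to a standard LRU cache keyed by image path. The sketch below shows that pattern in plain Python under stated assumptions: `LRUFeatureCache`, `get_or_compute`, and `fake_encode` are hypothetical names used for illustration and are not part of mlx-vlm's `VisionFeatureCache` API, and the "encoder" is a stand-in that only records how often it runs.

```python
from collections import OrderedDict


class LRUFeatureCache:
    """Hypothetical LRU cache for per-image vision features (illustration only)."""

    def __init__(self, max_entries=8):
        self.max_entries = max_entries
        self._entries = OrderedDict()

    def get_or_compute(self, key, compute):
        if key in self._entries:
            # Cache hit: mark the entry as most recently used and skip the encoder
            self._entries.move_to_end(key)
            return self._entries[key]
        # Cache miss: run the expensive encoder once and store its output
        value = compute()
        self._entries[key] = value
        if len(self._entries) > self.max_entries:
            # Evict the least recently used entry
            self._entries.popitem(last=False)
        return value


# Stand-in encoder that records every call so cache hits are visible
calls = []

def fake_encode(path):
    calls.append(path)
    return f"features({path})"

cache = LRUFeatureCache(max_entries=2)
for path in ["a.jpg", "a.jpg", "b.jpg", "a.jpg", "c.jpg", "b.jpg"]:
    cache.get_or_compute(path, lambda p=path: fake_encode(p))

print(calls)  # ['a.jpg', 'b.jpg', 'c.jpg', 'b.jpg']
```

With room for only two entries, the repeated `a.jpg` turns are hits, while loading `c.jpg` evicts `b.jpg` (the least recently used), so switching back to `b.jpg` re-runs the encoder.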
### CLI

All chat interfaces use `VisionFeatureCache` automatically:

```sh
# Gradio chat UI
python -m mlx_vlm.chat_ui --model google/gemma-4-26b-a4b-it

# Interactive chat with Rich UI (load images with the /image command)
python -m mlx_vlm.chat --model google/gemma-4-26b-a4b-it

# Inline chat mode
python -m mlx_vlm.generate \
  --model google/gemma-4-26b-a4b-it \
  --image path/to/image.jpg \
  --chat \
  --max-tokens 200
```

### Python

```python
from mlx_vlm import load, stream_generate, VisionFeatureCache
from mlx_vlm.prompt_utils import apply_chat_template

model, processor = load("google/gemma-4-26b-a4b-it")
cache = VisionFeatureCache()
image = "path/to/image.jpg"

# Turn 1 -- cache miss, encodes image
prompt1 = apply_chat_template(processor, model.config, "Describe this image.", num_images=1)
for chunk in stream_generate(model, processor, prompt1, image=[image], max_tokens=200, vision_cache=cache):
    print(chunk.text, end="")

# Turn 2 -- cache hit, skips vision encoder
prompt2 = apply_chat_template(processor, model.config, "What colors do you see?", num_images=1)
for chunk in stream_generate(model, processor, prompt2, image=[image], max_tokens=200, vision_cache=cache):
    print(chunk.text, end="")
```

### Server

The server caches vision features automatically across requests for the same image. No configuration is needed: the cache is created when a model loads and cleared when it unloads.

```sh
mlx_vlm.server --model google/gemma-4-26b-a4b-it
```

Multi-turn conversations via `/v1/chat/completions` (streaming and non-streaming) and `/responses` all benefit. The same image sent across multiple requests is encoded only once.
### Performance

Tested on `google/gemma-4-26b-a4b-it` over 10 multi-turn conversation turns:

| Metric | Without Cache | With Cache |
|--------|--------------|------------|
| Prompt TPS | ~48 | ~550-825 |
| Speedup | -- | **11x+** |
| Peak Memory | 52.66 GB | 52.66 GB (flat) |

Generation speed (~31 tok/s) and memory are unaffected -- only prompt processing gets faster.

## TurboQuant KV Cache

TurboQuant compresses the KV cache during generation, enabling longer context lengths with less memory while maintaining quality.

### Quick Start

```sh
# 3.5-bit KV cache quantization (3-bit keys + 4-bit values)
mlx_vlm generate \
  --model mlx-community/Qwen3.5-4B-4bit \
  --kv-bits 3.5 \
  --kv-quant-scheme turboquant \
  --prompt "Your long prompt here..."
```

```python
from mlx_vlm import load, generate

model, processor = load("mlx-community/Qwen3.5-4B-4bit")
prompt = "Your long prompt here..."

result = generate(
    model,
    processor,
    prompt,
    kv_bits=3.5,
    kv_quant_scheme="turboquant",
    max_tokens=256,
)
```

```sh
# Server with TurboQuant
mlx_vlm server \
  --model google/gemma-4-26b-a4b-it \
  --kv-bits 3.5 \
  --kv-quant-scheme turboquant
```

### How It Works

TurboQuant uses random rotation + codebook quantization ([arXiv:2504.19874](https://arxiv.org/abs/2504.19874)) to compress KV cache entries from 16 bits to 2-4 bits per dimension:

- **Keys & Values**: MSE codebook quantization with Hadamard rotation
- **Fractional bits** (e.g. 3.5): uses fewer bits for keys and more for values (3-bit K + 4-bit V)

Custom Metal kernels fuse score computation and value aggregation directly on packed quantized data, avoiding full dequantization during decode.

### Performance

Tested on Qwen3.5-4B-4bit at 128k context:

| Metric | Baseline | TurboQuant 3.5-bit |
|--------|----------|-------------------|
| KV Memory | 4.1 GB | 0.97 GB (**76% reduction**) |
| Peak Memory | 18.3 GB | 17.3 GB (**-1.0 GB**) |

At 512k+ contexts, TurboQuant's per-layer attention is **faster than FP16 SDPA** due to reduced memory bandwidth requirements.
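To make the rotate-then-quantize idea from How It Works concrete, here is a simplified NumPy sketch in that spirit. It is not mlx-vlm's implementation and not the paper's exact algorithm. Assumptions: the scalar codebook is fitted with 1-D k-means (Lloyd iterations) on the data instead of being precomputed, each vector's norm is kept in full precision, and there is no bit packing or fused Metal kernel. All function names are hypothetical.

```python
import numpy as np


def random_hadamard_rotation(dim, seed=0):
    # Orthonormal rotation: normalized Sylvester Hadamard matrix with random column sign flips.
    # dim must be a power of two.
    h = np.array([[1.0]])
    while h.shape[0] < dim:
        h = np.block([[h, h], [h, -h]])
    signs = np.random.default_rng(seed).choice([-1.0, 1.0], size=dim)
    return (h / np.sqrt(dim)) * signs


def fit_codebook(samples, bits, iters=25):
    # 1-D Lloyd (k-means) iterations give an MSE-minimizing scalar codebook with 2**bits levels
    k = 2 ** bits
    centroids = np.quantile(samples, (np.arange(k) + 0.5) / k)
    for _ in range(iters):
        codes = np.abs(samples[:, None] - centroids[None, :]).argmin(axis=1)
        for c in range(k):
            members = samples[codes == c]
            if members.size:
                centroids[c] = members.mean()
    return np.sort(centroids)


def quantize(x, rotation, codebook):
    # Keep per-vector norms, rotate the unit vectors, map each coordinate to a codebook index
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    rotated = (x / np.maximum(norms, 1e-12)) @ rotation.T
    codes = np.abs(rotated[..., None] - codebook).argmin(axis=-1).astype(np.uint8)
    return codes, norms


def dequantize(codes, norms, rotation, codebook):
    # Look up codebook values, undo the rotation (rotation is orthonormal), restore the norms
    return (codebook[codes] @ rotation) * norms


rng = np.random.default_rng(1)
keys = rng.normal(size=(256, 64))  # stand-in for 256 cached key vectors with head_dim 64
rotation = random_hadamard_rotation(64)
unit_rotated = (keys / np.linalg.norm(keys, axis=-1, keepdims=True)) @ rotation.T

errors = {}
for bits in (2, 3, 4):
    codebook = fit_codebook(unit_rotated.ravel(), bits)
    codes, norms = quantize(keys, rotation, codebook)
    recon = dequantize(codes, norms, rotation, codebook)
    errors[bits] = np.linalg.norm(recon - keys) / np.linalg.norm(keys)
    print(bits, "bits -> relative error", round(float(errors[bits]), 3))
```

The rotation spreads each vector's energy evenly across coordinates, so a single shared scalar codebook fits all of them reasonably well; the relative reconstruction error shrinks as `bits` grows, which mirrors the quality/compression trade-off described for the supported bit widths.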
Tested on gemma-4-31b-it at 128k context:

| Metric | Baseline | TurboQuant 3.5-bit |
|--------|----------|-------------------|
| KV Memory | 13.3 GB | 4.9 GB (**63% reduction**) |
| Peak Memory | 75.2 GB | 65.8 GB (**-9.4 GB**) |

### Supported Bit Widths

| Bits | Compression | Best For |
|------|------------|----------|
| 2 | ~8x | Maximum compression, some quality loss |
| 3 | ~5x | Good balance of quality and compression |
| 3.5 | ~4.5x | Recommended default (3-bit keys + 4-bit values) |
| 4 | ~4x | Best quality, moderate compression |

### Compatibility

TurboQuant automatically quantizes `KVCache` layers (global attention). Models with `RotatingKVCache` (sliding window) or `ArraysCache` (MLA/absorbed keys) keep their native cache format for those layers, since those caches are already memory-efficient.

TurboQuant is supported in both single-request generation and continuous batching on the server. In continuous batching mode, KV states are stored in TurboQuant's compressed format and dequantized at attention time (the custom Metal kernels are not yet batch-aware).

## Distributed Inference

mlx-vlm supports distributed inference across multiple computers. It shards the language model (not the vision tower), because the LLM is much larger and vision embeddings only need to be computed once. The parallel implementation is compatible with [mlx-lm](https://github.com/ml-explore/mlx-lm) sharding primitives.

See [docs/usage.md](https://github.com/Blaizzy/mlx-vlm/blob/main/docs/usage.md#distributed-inference) for command-line examples.

# Fine-tuning

MLX-VLM supports fine-tuning models with LoRA and QLoRA.

## LoRA & QLoRA

To learn more about LoRA, please refer to the [LoRA.md](https://github.com/Blaizzy/mlx-vlm/blob/main/mlx_vlm/LORA.MD) file.
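For readers new to LoRA, the core idea fits in a few lines. This is a minimal NumPy sketch of the general technique, not mlx-vlm's training code: a frozen weight `W` is augmented with a trainable low-rank product `B @ A`, scaled by `alpha / rank`. `LoRALinear` and its parameters are hypothetical names chosen for illustration. In QLoRA the frozen `W` is additionally stored in quantized form; that part is not shown here.

```python
import numpy as np


class LoRALinear:
    """Hypothetical LoRA layer: frozen weight plus a trainable low-rank update."""

    def __init__(self, weight, rank=8, alpha=16.0, seed=0):
        out_dim, in_dim = weight.shape
        rng = np.random.default_rng(seed)
        self.weight = weight                                # frozen base weight
        self.A = rng.normal(0.0, 0.01, size=(rank, in_dim))  # trainable, small random init
        self.B = np.zeros((out_dim, rank))                   # trainable, zero init: adapter starts as a no-op
        self.scale = alpha / rank

    def __call__(self, x):
        # y = x W^T + scale * (x A^T) B^T
        return x @ self.weight.T + self.scale * (x @ self.A.T) @ self.B.T

    def merged_weight(self):
        # Fold the adapter into one dense weight, e.g. for inference after training
        return self.weight + self.scale * self.B @ self.A


rng = np.random.default_rng(0)
W = rng.normal(size=(512, 512))
layer = LoRALinear(W, rank=8)
x = rng.normal(size=(4, 512))

same_at_init = np.allclose(layer(x), x @ W.T)    # True, because B starts at zero
layer.B = rng.normal(size=layer.B.shape)        # pretend training has updated B
merged_matches = np.allclose(layer(x), x @ layer.merged_weight().T)
trainable = layer.A.size + layer.B.size          # 8 * (512 + 512) = 8192 vs 512 * 512 = 262144 frozen
```

Only `A` and `B` receive gradients, so the trainable parameter count here is about 3% of the frozen layer, and the merged weight gives identical outputs without the extra matmuls.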

mlx-embeddings
Open Source


# MLX-Embeddings

[![image](https://img.shields.io/pypi/v/mlx-embeddings.svg)](https://pypi.python.org/pypi/mlx-embeddings) [![Upload Python Package](https://github.com/Blaizzy/mlx-embeddings/actions/workflows/python-publish.yaml/badge.svg)](https://github.com/Blaizzy/mlx-embeddings/actions/workflows/python-publish.yaml)

**MLX-Embeddings is a package for running Vision and Language Embedding models locally on your Mac using MLX.**

- Free software: GNU General Public License v3

## Features

- Generate embeddings for text and images using MLX models
- Support for single-item and batch processing
- Utilities for comparing text similarities

## Supported Model Architectures

MLX-Embeddings supports a variety of model architectures for text and multimodal tasks. Here's a breakdown of the currently supported architectures:

- XLM-RoBERTa (Cross-lingual Language Model - Robustly Optimized BERT Approach)
- BERT (Bidirectional Encoder Representations from Transformers)
- ModernBERT (modernized bidirectional encoder-only Transformer model)
- Qwen3 (Qwen3's embedding model)
- Qwen3-VL (multimodal Qwen3-VL embedding and reranking model)
- Llama Bidirectional (Llama-based bidirectional embedding models, e.g. NVIDIA NV-Embed)
- Llama Nemotron VL (multimodal vision-language embedding model with SigLIP vision + bidirectional Llama)
- OpenAI Privacy Filter (bidirectional GPT-OSS variant for PII token classification with sparse MoE, GQA + attention sinks, and YARN RoPE)

We're continuously working to expand support for additional model architectures. Check our GitHub repository or documentation for the most up-to-date list of supported models and their specific versions.

## Installation

You can install mlx-embeddings using pip:

```bash
pip install mlx-embeddings
```

## Usage

### Qwen3-VL

Qwen3-VL uses a model-specific processor and a high-level `model.process(...)` API for multimodal embedding and reranking.
#### Multimodal Embedding

```python
import mlx.core as mx
from mlx_embeddings import load

model, processor = load("Qwen/Qwen3-VL-Embedding-2B")

inputs = [
    {
        "text": "A woman playing with her dog on a beach at sunset.",
        "instruction": "Retrieve images or text relevant to the user's query.",
    },
    {
        "text": "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset."
    },
    {
        "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
    },
    {
        "text": "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset.",
        "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
    },
]

embeddings = model.process(inputs, processor=processor)
similarity = embeddings @ embeddings.T
mx.eval(embeddings, similarity)

print(embeddings.shape)  # (4, 2048)
print(similarity)
```

#### Multimodal Reranking

```python
import mlx.core as mx
from mlx_embeddings import load

model, processor = load("Qwen/Qwen3-VL-Reranker-2B")

inputs = {
    "instruction": "Retrieve images or text relevant to the user's query.",
    "query": {"text": "A woman playing with her dog on a beach at sunset."},
    "documents": [
        {
            "text": "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset."
        },
        {
            "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
        },
        {
            "text": "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset.",
            "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
        },
    ],
}

scores = model.process(inputs, processor=processor)
mx.eval(scores)

print(scores.shape)  # (3,)
print(scores)
```

### Single Item Embedding

#### Text Embedding

To generate an embedding for a single piece of text:

```python
from mlx_embeddings.utils import load

# Load the model and tokenizer
model_name = "mlx-community/all-MiniLM-L6-v2-4bit"
model, tokenizer = load(model_name)

# Prepare the text
text = "I like reading"

# Tokenize and generate embedding
input_ids = tokenizer.encode(text, return_tensors="mlx")
outputs = model(input_ids)
raw_embeds = outputs.last_hidden_state[:, 0, :]  # CLS token
text_embeds = outputs.text_embeds  # mean pooled and normalized embeddings
```

Note: `text_embeds` uses mean pooling for BERT and XLM-RoBERTa. For ModernBERT, the pooling strategy is set in the config file and defaults to mean pooling.

#### Masked Language Modeling

To predict a masked token with a masked language model:

```python
import mlx.core as mx
from mlx_embeddings.utils import load

# Load ModernBERT model and tokenizer
model, tokenizer = load("mlx-community/answerdotai-ModernBERT-base-4bit")

# Masked Language Modeling example
text = "The capital of France is [MASK]."
inputs = tokenizer.encode(text, return_tensors="mlx")
outputs = model(inputs)

# Get predictions for the masked token
masked_index = inputs.tolist()[0].index(tokenizer.mask_token_id)
predicted_token_id = mx.argmax(outputs.pooler_output[0, masked_index]).tolist()
predicted_token = tokenizer.decode(predicted_token_id)

print("Predicted token:", predicted_token)  # Should output: Paris
```

#### Sequence Classification

```python
from mlx_embeddings.utils import load

# Load the classifier model and tokenizer
model, tokenizer = load("NousResearch/Minos-v1")
id2label = model.config.id2label

# Sequence classification example
text = "<|user|> Explain the theory of relativity in simple terms. <|assistant|> Imagine space and time are like a stretchy fabric. Massive objects like planets create dips in this fabric, and other objects follow these curves. That's gravity! Also, the faster you move, the slower time passes for you compared to someone standing still"
inputs = tokenizer.encode(text, return_tensors="mlx")
outputs = model(inputs)

# Classification logits for the whole sequence
predictions = outputs.pooler_output[0]  # Shape: (num_labels,)
print(text)

# Print results
print("\nTop predictions for classification:")
for idx, logit in enumerate(predictions.tolist()):
    label = id2label[str(idx)]
    print(f"{label}: {logit:.3f}")
```

#### Token Classification (PII detection)

`openai/privacy-filter` is a bidirectional 1.5B-parameter / 50M-active sparse-MoE token classifier that tags personally identifiable information (PII) with BIOES spans over 8 categories (person, email, phone, URL, address, date, account number, secret).

```python
from itertools import groupby

import mlx.core as mx
from mlx_embeddings.utils import load

model, tokenizer = load("openai/privacy-filter")
id2label = model.config.id2label

text = "My name is Alice Smith and my email is alice@example.com. Phone: 555-1234."
inputs = tokenizer(text, return_tensors="mlx")
outputs = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])
preds = mx.argmax(outputs.logits, axis=-1)[0].tolist()

def entity(label_id):
    # Strip the BIOES prefix (e.g. "B-", "I-") to get the entity type; "O" means no entity
    label = id2label[str(label_id)]
    return None if label == "O" else label.split("-", 1)[-1]

# Group consecutive tokens that share an entity type and decode each span
for ent, group in groupby(zip(inputs["input_ids"][0].tolist(), preds), key=lambda x: entity(x[1])):
    if ent:
        span = tokenizer.decode([tid for tid, _ in group]).strip()
        print(f"{ent:18s} -> {span!r}")
```

### Batch Processing

#### Multiple Texts Comparison

To embed multiple texts and compare them using their embeddings:

```python
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import mlx.core as mx
from mlx_embeddings.utils import load

# Load the model and tokenizer
model, tokenizer = load("mlx-community/all-MiniLM-L6-v2-4bit")

def get_embedding(texts, model, tokenizer):
    inputs = tokenizer.batch_encode_plus(texts, return_tensors="mlx", padding=True, truncation=True, max_length=512)
    outputs = model(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"]
    )
    return outputs.text_embeds  # mean pooled and normalized embeddings

def compute_and_print_similarity(embeddings):
    B, _ = embeddings.shape
    # Convert the MLX array to NumPy for scikit-learn
    similarity_matrix = cosine_similarity(np.array(embeddings))
    print("Similarity matrix between sequences:")
    print(similarity_matrix)
    print("\n")

    for i in range(B):
        for j in range(i+1, B):
            print(f"Similarity between sequence {i+1} and sequence {j+1}: {similarity_matrix[i][j]:.4f}")

    return similarity_matrix

# Visualize results
def plot_similarity_matrix(similarity_matrix, labels):
    plt.figure(figsize=(5, 4))
    sns.heatmap(similarity_matrix, annot=True, cmap='coolwarm', xticklabels=labels, yticklabels=labels)
    plt.title('Similarity Matrix Heatmap')
    plt.tight_layout()
    plt.show()

# Sample texts
texts = [
    "I like grapes",
    "I like fruits",
    "The slow green turtle crawls under the busy ant."
]

embeddings = get_embedding(texts, model, tokenizer)
similarity_matrix = compute_and_print_similarity(embeddings)

# Visualize results
labels = [f"Text {i+1}" for i in range(len(texts))]
plot_similarity_matrix(similarity_matrix, labels)
```

#### Masked Language Modeling

To get predictions for the masked token in multiple texts:

```python
import mlx.core as mx
from mlx_embeddings.utils import load

# Load the model and tokenizer
model, tokenizer = load("mlx-community/answerdotai-ModernBERT-base-4bit")

text = ["The capital of France is [MASK].", "The capital of Poland is [MASK]."]
inputs = tokenizer.batch_encode_plus(text, return_tensors="mlx", padding=True, truncation=True, max_length=512)
outputs = model(**inputs)

# Find the mask token index in each sequence of the batch
mask_indices = mx.array([ids.tolist().index(tokenizer.mask_token_id) for ids in inputs["input_ids"]])

# Get predictions for all masked tokens at once
batch_indices = mx.arange(len(mask_indices))
predicted_token_ids = mx.argmax(outputs.pooler_output[batch_indices, mask_indices], axis=-1).tolist()

# Decode the predicted tokens
predicted_token = tokenizer.batch_decode(predicted_token_ids)
print("Predicted token:", predicted_token)
# Predicted token: Paris, Warsaw
```

## Vision Transformer Models

MLX-Embeddings also supports vision models that can generate embeddings for images or image-text pairs.
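The SigLIP examples below apply `mx.sigmoid` to `logits_per_image` rather than a softmax. The NumPy sketch below shows why, assuming the standard SigLIP formulation: cosine similarity scaled by `exp(logit_scale)` plus a learned `logit_bias`, with each image-text pair scored independently. The function name and the numbers in the demo are illustrative only and do not come from mlx-embeddings' internals or a trained checkpoint.

```python
import numpy as np


def siglip_match_probs(image_embeds, text_embeds, logit_scale, logit_bias):
    # L2-normalize both sides so the dot product is a cosine similarity
    img = image_embeds / np.linalg.norm(image_embeds, axis=-1, keepdims=True)
    txt = text_embeds / np.linalg.norm(text_embeds, axis=-1, keepdims=True)
    # logits_per_image[i, j] = exp(logit_scale) * cos(img_i, txt_j) + logit_bias
    logits_per_image = np.exp(logit_scale) * img @ txt.T + logit_bias
    # Independent sigmoid per pair (not a softmax over texts), so a row need not sum to 1
    return 1.0 / (1.0 + np.exp(-logits_per_image))


# Toy 2-D embeddings: 2 images x 3 texts (illustrative values, not model outputs)
image_embeds = np.array([[1.0, 0.0], [0.0, 1.0]])
text_embeds = np.array([[1.0, 0.0], [0.6, 0.8], [-1.0, 0.0]])
probs = siglip_match_probs(image_embeds, text_embeds, logit_scale=np.log(10.0), logit_bias=-5.0)
print(probs.round(3))
```

Because each pair gets its own sigmoid, an image can match several captions at once (the first image scores highly on both of the first two texts), which is why the examples report a per-text probability instead of a distribution over texts.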
### Single Image Processing

To evaluate how well an image matches different text descriptions:

```python
import mlx.core as mx
from mlx_embeddings.utils import load
import requests
from PIL import Image

# Load vision model and processor
model, processor = load("mlx-community/siglip-so400m-patch14-384")

# Load an image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Create text descriptions to compare with the image
texts = ["a photo of 2 dogs", "a photo of 2 cats"]

# Process inputs
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="np")
pixel_values = mx.array(inputs.pixel_values).transpose(0, 2, 3, 1).astype(mx.float32)
input_ids = mx.array(inputs.input_ids)

# Generate embeddings and calculate similarity
outputs = model(pixel_values=pixel_values, input_ids=input_ids)
logits_per_image = outputs.logits_per_image
probs = mx.sigmoid(logits_per_image)  # probabilities of image matching each text

# Print results
print(f"{probs[0][0].item():.1%} that image matches '{texts[0]}'")
print(f"{probs[0][1].item():.1%} that image matches '{texts[1]}'")
```

### Batch Processing for Multiple Image Comparison

Process multiple images and compare them with text descriptions:

```python
import mlx.core as mx
from mlx_embeddings.utils import load
import requests
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

# Load vision model and processor
model, processor = load("mlx-community/siglip-so400m-patch14-384")

# Load multiple images (local paths or http(s) URLs)
image_urls = [
    "./images/cats.jpg",  # cats
    "./images/desktop_setup.png"  # desktop setup
]
images = [
    Image.open(requests.get(url, stream=True).raw) if url.startswith("http") else Image.open(url)
    for url in image_urls
]

# Text descriptions
texts = ["a photo of cats", "a photo of a desktop setup", "a photo of a person"]

# Process all image-text pairs in one batch
inputs = processor(text=texts, images=images, padding="max_length", return_tensors="np")
pixel_values = mx.array(inputs.pixel_values).transpose(0, 2, 3, 1).astype(mx.float32)
input_ids = mx.array(inputs.input_ids)

# Generate embeddings and calculate similarity
outputs = model(pixel_values=pixel_values, input_ids=input_ids)
logits_per_image = outputs.logits_per_image
probs = mx.sigmoid(logits_per_image)  # shape: (num_images, num_texts)
all_probs = probs.tolist()

# Print results for each image
for i, image in enumerate(images):
    print(f"Image {i+1}:")
    for j, text in enumerate(texts):
        print(f"  {all_probs[i][j]:.1%} match with '{text}'")
    print()

# Visualize results with a heatmap
def plot_similarity_matrix(probs_matrix, image_labels, text_labels):
    # Convert to 2D numpy array if needed
    import numpy as np
    probs_matrix = np.array(probs_matrix)

    # Ensure we have a 2D matrix for the heatmap
    if probs_matrix.ndim > 2:
        probs_matrix = probs_matrix.squeeze()

    plt.figure(figsize=(8, 5))
    sns.heatmap(probs_matrix, annot=True, cmap='viridis',
                xticklabels=text_labels, yticklabels=image_labels,
                fmt=".1%", vmin=0, vmax=1)
    plt.title('Image-Text Match Probability')
    plt.tight_layout()
    plt.show()

# Plot the images for reference
plt.figure(figsize=(8, 5))
for i, image in enumerate(images):
    plt.subplot(1, len(images), i+1)
    plt.imshow(image)
    plt.title(f"Image {i+1}")
    plt.axis('off')
plt.tight_layout()
plt.show()

image_labels = [f"Image {i+1}" for i in range(len(images))]
plot_similarity_matrix(all_probs, image_labels, texts)
```

### Late Interaction Multimodal Retrieval Models (ColPali/ColQwen)

```python
import mlx.core as mx
import requests
from io import BytesIO
from PIL import Image
from transformers import AutoImageProcessor

from mlx_embeddings import load
from mlx_embeddings.models.base import normalize_embeddings

# Load the model and tokenizer returned by mlx-embeddings
model, tokenizer = load("qnguyen3/colqwen2.5-v0.2-mlx")
image_processor = AutoImageProcessor.from_pretrained("qnguyen3/colqwen2.5-v0.2-mlx")

def fetch_image(url):
    response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}, timeout=60)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")

def nonpad_rows(embeds, attention_mask):
    # Keep only the token embeddings at non-padding positions
    indices = [i for i, value in enumerate(attention_mask[0].tolist()) if value != 0]
    return embeds[0, indices, :]

def prepare_query(text):
    suffix = tokenizer.pad_token * 10
    query = "Query: " + text + suffix
    inputs = tokenizer([query], return_tensors="np", padding=True)
    return {
        "input_ids": mx.array(inputs["input_ids"]),
        "attention_mask": mx.array(inputs["attention_mask"]),
    }

def prepare_image(image):
    image_inputs = image_processor(
        images=[image],
        return_tensors="np",
        data_format="channels_first",
        do_convert_rgb=True,
    )
    image_grid_thw = mx.array(image_inputs["image_grid_thw"])
    num_image_tokens = int(
        image_inputs["image_grid_thw"][0].prod() // (image_processor.merge_size ** 2)
    )
    prompt = (
        "<|im_start|>user\n"
        "<|vision_start|><|image_pad|><|vision_end|>"
        "Describe the image.<|im_end|><|endoftext|>"
    )
    # Expand the single image placeholder to one token per merged image patch
    prompt = prompt.replace("<|image_pad|>", "<|image_pad|>" * num_image_tokens)
    text_inputs = tokenizer([prompt], return_tensors="np", padding=True)
    return {
        "input_ids": mx.array(text_inputs["input_ids"]),
        "attention_mask": mx.array(text_inputs["attention_mask"]),
        "pixel_values": mx.array(image_inputs["pixel_values"]),
        "image_grid_thw": image_grid_thw,
    }

def embed_query(text):
    inputs = prepare_query(text)
    inputs_embeds = model.get_input_embeddings_batch(inputs["input_ids"])
    position_ids, _ = model.vlm.language_model.get_rope_index(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    hidden = model.vlm.language_model.model(
        None,
        inputs_embeds=inputs_embeds,
        mask=None,
        cache=None,
        position_ids=position_ids,
    )
    embeds = normalize_embeddings(model.embedding_proj_layer(hidden))
    embeds = embeds * inputs["attention_mask"][:, :, None]
    return nonpad_rows(embeds, inputs["attention_mask"])

def embed_image(image):
    inputs = prepare_image(image)
    inputs_embeds = model.get_input_embeddings_batch(
        inputs["input_ids"],
        inputs["pixel_values"],
        inputs["image_grid_thw"],
    )
    position_ids, _ = model.vlm.language_model.get_rope_index(
        inputs["input_ids"],
        image_grid_thw=inputs["image_grid_thw"],
        attention_mask=inputs["attention_mask"],
    )
    hidden = model.vlm.language_model.model(
        None,
        inputs_embeds=inputs_embeds,
        mask=None,
        cache=None,
        position_ids=position_ids,
    )
    embeds = normalize_embeddings(model.embedding_proj_layer(hidden))
    embeds = embeds * inputs["attention_mask"][:, :, None]
    return nonpad_rows(embeds, inputs["attention_mask"])

def maxsim(query_embeds, image_embeds):
    # Late interaction: each query token takes its best-matching image token, then sum
    sims = query_embeds @ image_embeds.T
    return mx.sum(mx.max(sims, axis=1))

texts = ["how many percent of data are books?", "evaluation results between models"]
images = [
    fetch_image("https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"),
    fetch_image("https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"),
]

query_embeddings = [embed_query(text) for text in texts]
image_embeddings = [embed_image(image) for image in images]
scores = [[float(maxsim(q, d)) for d in image_embeddings] for q in query_embeddings]

print([embedding.shape for embedding in query_embeddings])
print([embedding.shape for embedding in image_embeddings])
print(scores)
```

## Model Conversion

### Converting Hugging Face Models to MLX Format

You can convert Hugging Face models to MLX format using the `mlx-embeddings` conversion tool:

```bash
python -m mlx_embeddings.convert \
    --hf-path <huggingface-model-id-or-path> \
    --mlx-path <output-path>
```

### Quantization

The conversion tool supports quantization to reduce model size and improve inference speed:

```bash
# Default affine quantization (group_size=64, bits=4)
python -m mlx_embeddings.convert \
    --hf-path <huggingface-model-id-or-path> \
    --mlx-path <output-path> \
    --quantize
```

#### Quantization Modes

The `--q-mode` option specifies which quantization mode to use. Supported modes are:

| Mode | Default Group Size | Default Bits | Use Case |
|------|-----------|------|----------|
| `affine` (default) | 64 | 4 | General-purpose quantization |
| `mxfp4` | 32 | 4 | Microscaling (MX) floating-point 4-bit |
| `nvfp4` | 16 | 4 | NVIDIA floating-point 4-bit |
| `mxfp8` | 32 | 8 | Microscaling (MX) floating-point 8-bit (higher precision) |

**Examples:**

```bash
# mxfp4 quantization with default settings
python -m mlx_embeddings.convert \
    --hf-path <model> \
    --mlx-path <output-path> \
    --quantize \
    --q-mode mxfp4

# nvfp4 quantization with custom group size and bits
python -m mlx_embeddings.convert \
    --hf-path <model> \
    --mlx-path <output-path> \
    --quantize \
    --q-mode nvfp4 \
    --q-group-size 32 \
    --q-bits 6

# mxfp8 for higher precision (8-bit)
python -m mlx_embeddings.convert \
    --hf-path <model> \
    --mlx-path <output-path> \
    --quantize \
    --q-mode mxfp8
```

**Note:** User-specified `--q-group-size` and `--q-bits` values override mode defaults.

### Other Conversion Options

- `--dtype`: Convert to a specific dtype (`float16`, `bfloat16`, `float32`). Defaults to `float16`.
- `--dequantize`: Dequantize a previously quantized model.
- `--upload-repo`: Upload the converted model to the Hugging Face Hub.

## Contributing

Contributions to MLX-Embeddings are welcome! Please refer to our contribution guidelines for more information.

## License

This project is licensed under the GNU General Public License v3.

## Contact

For any questions or issues, please open an issue on the [GitHub repository](https://github.com/Blaizzy/mlx-embeddings).
