← Articles EN
vLLM · DGX Spark · Blackwell

Servir Mistral-Small-4-119B avec vLLM sur DGX Spark

Installation complète de vLLM et configuration optimisée pour servir un modèle de 119B paramètres quantifié en NVFP4 sur un GPU Blackwell consumer GB10 avec 128 GiB de mémoire unifiée.

Sébastien Burel · haruni.net · Mars 2026

Introduction

Le modèle : Mistral Small 4

mistralai/Mistral-Small-4-119B-2603-NVFP4 est un modèle hybride remarquable publié par Mistral AI en mars 2026. Il unifie en un seul point de contrôle trois familles de modèles — Instruct, Reasoning (anciennement Magistral) et Devstral — offrant ainsi une polyvalence exceptionnelle pour un usage général.

Son architecture MoE (Mixture of Experts) le rend particulièrement efficace :

Ce checkpoint en particulier est une version quantifiée en NVFP4 (post-training activation quantization), créée en collaboration avec les équipes vLLM et Red Hat via llm-compressor. Cette quantisation réduit significativement la mémoire requise tout en conservant de bonnes performances — ce qui le rend justement servable sur un DGX Spark avec 128 GiB de mémoire unifiée.

En termes de performance, Mistral Small 4 offre une réduction de 40% du temps de complétion end-to-end en configuration optimisée latence, et 3x plus de requêtes par seconde en configuration optimisée débit, par rapport à Mistral Small 3.

Le matériel : DGX Spark

Le DGX Spark est une machine compacte mais remarquable : elle embarque un GPU NVIDIA GB10 (architecture Blackwell SM121) avec 128 GiB de mémoire unifiée CPU+GPU. C'est suffisant pour faire tourner localement ce modèle de 119B paramètres quantifié, dont les poids occupent ~66 GiB en mémoire.

Cet article documente l'installation complète de vLLM et les choix de configuration pour servir ce modèle efficacement, en détaillant les problèmes rencontrés avec le GPU Blackwell consumer et les solutions trouvées.


Contexte matériel

ComposantValeur
GPUNVIDIA GB10 (Blackwell, SM121)
Architecture CUDA12.1 (12.1a pour la variante consumer)
Mémoire unifiée128 GiB
CUDA13.0
Architecture CPUaarch64
SystèmeUbuntu 24

Important : le GB10 utilise SM121 (Blackwell consumer) et non SM100 (Blackwell datacenter). Cette distinction est cruciale pour la compatibilité des kernels CUDA.


1. Téléchargement du modèle

hf download mistralai/Mistral-Small-4-119B-2603-NVFP4

Le modèle pèse ~66 GiB. Il sera stocké dans ~/.cache/huggingface/hub/.

Note stockage : si le modèle est sur un disque USB ou un SSD externe (~475 MB/s), comptez ~10 minutes de chargement à chaque démarrage. Un NVMe interne serait idéal mais le DGX Spark a un espace NVMe limité (~900 GiB dont une grande partie occupée par le système).


2. Installation de vLLM

2.1 Cloner le dépôt

Le dépôt officiel de vLLM ne supporte pas encore parfaitement le parsing Mistral v15. On utilise un fork corrigé :

git clone --branch fix_mistral_parsing https://github.com/juliendenize/vllm.git vllm-mistral
cd vllm-mistral

2.2 Créer l'environnement virtuel

uv venv
source .venv/bin/activate

2.3 Installer vLLM avec les binaires précompilés

VLLM_USE_PRECOMPILED=1 uv pip install --editable .

2.4 Installer PyTorch pour CUDA 13

uv pip install --index-url https://download.pytorch.org/whl/cu130 torch==2.10.0+cu130
uv pip install --index-url https://download.pytorch.org/whl/cu130 torchvision

2.5 Installer la dernière version de Transformers

Le tokenizer Mistral v15 nécessite une version récente de mistral_common, incluse via Transformers :

uv pip install git+https://github.com/huggingface/transformers.git
pip install --upgrade mistral_common

2.6 Installer FlashInfer

FlashInfer fournit des kernels optimisés pour l'attention. Pour CUDA 13 :

uv pip install flashinfer-python flashinfer-cubin
pip install flashinfer-jit-cache --index-url https://flashinfer.ai/whl/cu130

3. Configuration du service systemd

3.1 Fichier de service

sudo systemctl edit --force --full vllm.service
[Unit]
Description=vLLM Inference Server
After=network.target

[Service]
User=sb
Group=sb
WorkingDirectory=/mnt/data/sb/projects/vllm-mistral
Environment="PATH=/mnt/data/sb/projects/vllm-mistral/.venv/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
Environment="LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
Environment="TORCH_CUDA_ARCH_LIST=12.1a"
Environment="TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas"
Environment="FLASHINFER_JIT_LOG_LEVEL=ERROR"
Environment="TRANSFORMERS_VERBOSITY=error"
Environment="VLLM_SKIP_P2P_CHECK=1"
ExecStart=/mnt/data/sb/projects/vllm-mistral/.venv/bin/vllm serve mistralai/Mistral-Small-4-119B-2603-NVFP4 \
    --max-model-len 262144 \
    --tensor-parallel-size 1 \
    --attention-backend TRITON_MLA \
    --tool-call-parser mistral \
    --enable-auto-tool-choice \
    --reasoning-parser mistral \
    --max-num-batched-tokens 16384 \
    --max-num-seqs 128 \
    --gpu-memory-utilization 0.8 \
    --no-enable-flashinfer-autotune \
    --cudagraph-capture-sizes 1 2 4 8 16 32 64 128 256 \
    --max-cudagraph-capture-size 256
Restart=always
RestartSec=15
StandardOutput=journal
StandardError=journal
SyslogIdentifier=vllm
MemoryMax=120G
MemorySwapMax=0

[Install]
WantedBy=multi-user.target
sudo systemctl daemon-reload
sudo systemctl enable --now vllm

3.2 Vérification

# Logs en temps réel
journalctl -u vllm -f

# Vérifier que le modèle est bien servi
curl http://localhost:8000/v1/models

4. Explication des choix de configuration

--attention-backend TRITON_MLA

Le GB10 (SM121) n'est pas compatible avec FlashAttention ni FlashInfer MLA qui sont compilés pour SM100 (datacenter Blackwell). Sans cette option, vLLM crashe avec cudaErrorIllegalInstruction.

Triton recompile les kernels d'attention à la volée pour SM121, ce qui résout le problème.

--no-enable-flashinfer-autotune

C'est l'optimisation la plus impactante sur le temps de démarrage. Sans ce flag, l'autotuner FlashInfer teste des dizaines de tactiques MoE compilées pour SM120 (datacenter) qui échouent toutes sur SM121. Sur le 2ème démarrage, cela provoquait un délai de 1300 secondes (~22 minutes) supplémentaires.

Avec --no-enable-flashinfer-autotune, ce délai disparaît complètement :

Skipping FlashInfer autotune because it is disabled.

--cudagraph-capture-sizes 1 2 4 8 16 32 64 128 256

Par défaut, vLLM capture des CUDA graphs pour des dizaines de tailles de batch (jusqu'à 512 par paliers de 8). Pour un usage interactif avec un seul utilisateur, on n'a pas besoin de toutes ces tailles. Réduire à 9 tailles (puissances de 2) réduit drastiquement le temps de capture.

--max-cudagraph-capture-size 256

Limite la taille maximale de batch capturée. Cohérent avec --max-num-seqs 128.

--max-model-len 262144

Le modèle supporte jusqu'à 1M tokens de contexte, mais avec 128 GiB de mémoire unifiée dont ~66 GiB pour les poids, la mémoire KV cache disponible est limitée (~13 GiB). 262144 tokens (256K) est un bon compromis.

--gpu-memory-utilization 0.8

80% de la mémoire GPU est allouée pour le modèle + KV cache. Les 20% restants servent aux CUDA graphs et overheads runtime.

VLLM_SKIP_P2P_CHECK=1

Évite une vérification P2P qui peut prendre ~60 secondes sur certaines configurations à GPU unique.

TORCH_CUDA_ARCH_LIST=12.1a

Indique à PyTorch et Triton de compiler les kernels pour SM121 (la variante a désigne le GB10 consumer par opposition au B200/H100).

TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas

Pointe Triton vers le compilateur PTX de CUDA 13. Sans cela, Triton peut utiliser une version incompatible.


5. Performances de démarrage

Voici les temps mesurés sur un démarrage avec toutes les optimisations actives et le cache torch.compile chaud :

ÉtapeDurée
Chargement des poids (13 shards, SSD USB)606s
torch.compile (cache hit)4s
Warmup initial3s
CUDA graph capture (9 tailles)13s
init engine total104s
Démarrage total~12 min 27s
10:13:14 → démarrage du service
10:23:53 → Loading weights took 606.35 seconds
10:25:13 → torch.compile took 3.96 s (cache hit)
10:25:38 → Graph capturing finished in 13 secs
10:25:39 → init engine took 103.80 seconds
10:25:41 → Application startup complete

Le cache torch.compile est automatiquement stocké dans ~/.cache/vllm/torch_compile_cache/. Dès le 2ème démarrage, la compilation est quasi-instantanée (~4s au lieu de ~15s).

Le chargement des poids depuis un SSD USB (~475 MB/s) représente la majorité du temps de démarrage (~10 min) et est incompressible sans changer le stockage.


6. Test rapide

curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "mistralai/Mistral-Small-4-119B-2603-NVFP4",
    "messages": [{"role": "user", "content": "Explique ce qu'\''est un LLM en 3 phrases."}],
    "max_tokens": 256
  }'

7. Interface web avec Open WebUI

Pour une interface graphique accessible depuis un autre poste :

docker run -d \
  --network=host \
  --name open-webui \
  -e OPENAI_API_BASE_URL=http://localhost:8000/v1 \
  -e OPENAI_API_KEY=none \
  -v open-webui:/app/backend/data \
  --restart always \
  ghcr.io/open-webui/open-webui:main

Interface accessible sur http://<IP_DGX>:8080.


8. Problèmes connus

Warning PyTorch sur SM121

Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
Minimum and Maximum cuda capability supported by this version of PyTorch is (8.0) - (12.0)

Ce warning est inoffensif. PyTorch 2.10+cu130 a été compilé pour SM 8.0–12.0 officiellement, mais fonctionne correctement sur SM 12.1 grâce à la recompilation avec TORCH_CUDA_ARCH_LIST=12.1a.

Erreur repo_utils safetensors

ERROR: 'mistralai/Mistral-Small-4-119B-2603-NVFP4' is not a safetensors repo.

Ce message d'erreur est normal : ce modèle utilise le format consolidated.safetensors de Mistral et non le format HuggingFace standard. vLLM gère ça correctement en fallback.

Tensorizer incompatible

Tensorizer (qui permettrait de réduire le temps de chargement) est incompatible avec les modèles quantifiés en NVFP4 / compressed-tensors. Cette piste est à abandonner pour ce modèle.


Conclusion

Faire tourner un modèle de 119B paramètres localement sur un DGX Spark est tout à fait réalisable, avec quelques spécificités liées au GPU Blackwell consumer GB10. Les points clés :

Backend d'attention
TRITON_MLA
Désactiver l'autotuner
--no-enable-flashinfer-autotune
📊
Réduire les CUDA graphs
9 tailles (puissances de 2)

Le temps de démarrage incompressible reste dominé par le chargement des poids (~10 min sur SSD USB), mais une fois lancé, le serveur reste stable et performant pour un usage personnel ou en équipe.

Intéressé par un projet similaire ? Je suis disponible pour des missions de machine learning engineering — déploiement de LLMs, optimisation d'inférence, mise en production. N'hésitez pas à me contacter.