vLLM logs tell you exactly what’s happening during model loading and inference—memory allocation, attention backends, CUDA graph capture, KV cache sizing. Understanding these logs helps you debug performance issues, optimize configurations, and reason about why your setup behaves the way it does.

This post walks through the startup logs from serving GPT-OSS-120B (a 117B parameter MoE model with MXFP4 quantization) on a single GPU via vLLM v0.11.2. Each log line is explained with links to source code and documentation. Logs are from this Kaggle notebook.

INFO 11-28 11:45:51 [scheduler.py:216] Chunked prefill is enabled with max_num_batched_tokens=2048.

max_num_batched_tokens caps tokens per scheduler step. With chunked prefill enabled (default in V1), the scheduler prioritizes decode requests and batches pending prefills into remaining token budget. Lower values (e.g., 2048) achieve better ITL (inter-token latency) because there are fewer prefills interrupting decodes. Higher values achieve better TTFT (time-to-first-token) as more prefill tokens are processed per batch. Default is 8192 for online serving, 16384 for offline. See vLLM Optimization docs.

This value also affects how much memory is left for KV cache. vLLM allocates all remaining GPU memory to KV cache after loading model weights—controlled by gpu_memory_utilization (default 0.9). Higher max_num_batched_tokens reserves more activation memory during the profiling step, leaving less for KV cache and reducing max_concurrency. Decreasing max_num_batched_tokens or max_num_seqs frees KV cache space for more concurrent requests.
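
A simplified sketch of the scheduling policy described above (not vLLM's actual scheduler; the request dictionaries and field names are illustrative):

```python
# Toy model of chunked-prefill scheduling: running decode requests are admitted
# first, then pending prefills are chunked into whatever token budget remains.
def schedule_step(decode_reqs, prefill_reqs, max_num_batched_tokens=2048):
    budget = max_num_batched_tokens
    batch = []

    # Each running (decode) request needs exactly one token this step.
    for req in decode_reqs:
        if budget == 0:
            break
        batch.append((req, 1))
        budget -= 1

    # Remaining budget goes to waiting prefills, split into chunks.
    for req in prefill_reqs:
        if budget == 0:
            break
        chunk = min(req["remaining_prompt_tokens"], budget)
        batch.append((req, chunk))
        budget -= chunk

    return batch  # list of (request, tokens processed this step)
```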

(APIServer pid=90) INFO 11-28 11:45:51 [api_server.py:1977] vLLM API server version 0.11.2

(APIServer pid=90) INFO 11-28 11:45:51 [utils.py:253] non-default args: {'host': '0.0.0.0', 'model': '/kaggle/input/gpt-oss-120b/transformers/default/1', 'max_model_len': 98304, 'served_model_name': ['vllm-model'], 'gpu_memory_utilization': 0.96, 'max_num_seqs': 6}

max_num_seqs caps concurrent sequences per batch (default 1024 in V1, up from 256 in V0). Higher values allow more concurrent requests but require more KV cache space at runtime. Lower values reduce memory pressure and avoid preemption, where requests are evicted and recomputed when KV cache fills. Here, max_num_seqs=6 is set low for this memory-constrained single-GPU setup.
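
For reference, the same knobs can be set through the offline LLM API (a minimal sketch mirroring the non-default args in the log; the same options map to vllm serve flags such as --max-num-seqs and --gpu-memory-utilization):

```python
from vllm import LLM

# Mirrors the non-default args reported in the log above.
llm = LLM(
    model="/kaggle/input/gpt-oss-120b/transformers/default/1",
    max_model_len=98304,            # longest supported context per request
    gpu_memory_utilization=0.96,    # fraction of GPU memory vLLM may claim
    max_num_seqs=6,                 # cap on concurrent sequences per batch
    max_num_batched_tokens=2048,    # optional: smaller token budget for better ITL
)
```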

(APIServer pid=90) INFO 11-28 11:46:35 [model.py:631] Resolved architecture: GptOssForCausalLM

vLLM reads architectures from config.json and maps it to the corresponding model class in vllm/model_executor/models/.
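
What this resolution amounts to, sketched with plain Python (the registry lookup itself lives inside vLLM; only the config.json read is shown):

```python
import json

# config.json in the checkpoint directory declares the architecture class name.
with open("/kaggle/input/gpt-oss-120b/transformers/default/1/config.json") as f:
    config = json.load(f)

print(config["architectures"])  # ['GptOssForCausalLM']

# vLLM maps this string to the model class registered under
# vllm/model_executor/models/ (here: the GPT-OSS implementation).
```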

(APIServer pid=90) ERROR 11-28 11:46:35 [config.py:307] Error retrieving safetensors: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/kaggle/input/gpt-oss-120b/transformers/default/1'. Use repo_type argument if needed., retrying 1 of 2

(APIServer pid=90) ERROR 11-28 11:46:37 [config.py:305] Error retrieving safetensors: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/kaggle/input/gpt-oss-120b/transformers/default/1'. Use repo_type argument if needed.

These errors are benign here: vLLM tries to look up safetensors metadata on the Hugging Face Hub, but the model argument is a local path rather than a repo id, so the lookup fails and loading proceeds from the local directory.

(APIServer pid=90) INFO 11-28 11:46:37 [model.py:1968] Downcasting torch.float32 to torch.bfloat16.

This is about the original checkpoint dtype. Since GPT-OSS uses quantization=mxfp4, weights end up as 4-bit anyway. bfloat16 is used for activations and KV cache.

(APIServer pid=90) INFO 11-28 11:46:37 [model.py:1745] Using max model len 98304

(APIServer pid=90) INFO 11-28 11:46:45 [scheduler.py:216] Chunked prefill is enabled with max_num_batched_tokens=8192.

The message appears twice: first during APIServer config parsing with a provisional value of 2048, then again when EngineCore initializes with the final computed value of 8192, the online-serving default.

(APIServer pid=90) INFO 11-28 11:46:45 [config.py:272] Overriding max cuda graph capture size to 1024 for performance.

CUDA graphs pre-record GPU operations to eliminate per-kernel CPU launch overhead. Each captured batch size requires memory to store the graph. Limiting to 1024 max balances memory usage against coverage of typical batch sizes. See vLLM CUDA Graphs docs.

(EngineCore_DP0 pid=289) INFO 11-28 11:47:23 [core.py:93] Initializing a V1 LLM engine (v0.11.2) with config: model='/kaggle/input/gpt-oss-120b/transformers/default/1', speculative_config=None, tokenizer='/kaggle/input/gpt-oss-120b/transformers/default/1', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=98304, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=mxfp4, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='openai_gptoss', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=vllm-model, enable_prefix_caching=True, enable_chunked_prefill=True, pooler_config=None, compilation_config={'level': None, 'mode': <CompilationMode.VLLM_COMPILE: 3>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'inductor', 'custom_ops': ['none'], 'splitting_ops': ['vllm::unified_attention', 'vllm::unified_attention_with_output', 'vllm::unified_mla_attention', 'vllm::unified_mla_attention_with_output', 'vllm::mamba_mixer2', 'vllm::mamba_mixer', 'vllm::short_conv', 'vllm::linear_attention', 'vllm::plamo2_mamba_mixer', 'vllm::gdn_attention_core', 'vllm::kda_attention', 'vllm::sparse_attn_indexer'], 'compile_mm_encoder': False, 'use_inductor': None, 'compile_sizes': [], 'inductor_compile_config': {'enable_auto_functionalized_v2': False, 'combo_kernels': True, 'benchmark_combo_kernel': True}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.FULL_AND_PIECEWISE: (2, 1)>, 'cudagraph_num_of_warmups': 1, 'cudagraph_capture_sizes': [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496, 512, 528, 544, 560, 576, 592, 608, 624, 640, 656, 672, 688, 704, 720, 736, 752, 768, 784, 800, 816, 832, 848, 864, 880, 896, 912, 928, 944, 960, 976, 992, 1008, 1024], 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {}, 'max_cudagraph_capture_size': 1024, 'local_cache_dir': None}

Key settings for GPT-OSS:

  • quantization=mxfp4: MXFP4 (Microscaling FP4) compresses weights to 4-bit with per-group scaling factors (a toy sketch of the format follows this list)
  • enable_prefix_caching=True: Automatic Prefix Caching shares KV cache for requests with common prefixes
  • cudagraph_mode=FULL_AND_PIECEWISE: FULL_AND_PIECEWISE uses full CUDA graphs for uniform decode batches, piecewise graphs for mixed prefill-decode batches
  • reasoning_parser='openai_gptoss': GPT-OSS specific reasoning output parser
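
A toy illustration of the MXFP4 idea as described in the OCP Microscaling spec: blocks of 32 values share one power-of-two scale, and each element is stored as 4-bit FP4 (E2M1). This is a conceptual sketch with simplified scale selection, not vLLM's kernel code:

```python
import numpy as np

# FP4 (E2M1) can represent these magnitudes (plus a sign bit):
FP4_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def mxfp4_quantize_block(block):
    """Quantize one block of 32 values: shared power-of-two scale + FP4 elements."""
    # Pick a power-of-two scale so the largest magnitude fits within FP4's max (6.0);
    # the real spec derives the shared scale from the block's maximum exponent.
    amax = np.abs(block).max()
    scale = 2.0 ** np.ceil(np.log2(amax / 6.0)) if amax > 0 else 1.0
    # Round each scaled value to the nearest FP4 grid point, keeping the sign.
    scaled = block / scale
    idx = np.abs(FP4_GRID[None, :] - np.abs(scaled)[:, None]).argmin(axis=1)
    q = np.sign(scaled) * FP4_GRID[idx]
    return scale, q            # stored: one scale per block + 32 4-bit codes

def mxfp4_dequantize_block(scale, q):
    return q * scale           # roughly what a weight-only kernel does on the fly

block = np.random.randn(32).astype(np.float32)
scale, q = mxfp4_quantize_block(block)
print(np.abs(block - mxfp4_dequantize_block(scale, q)).max())  # quantization error
```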

(EngineCore_DP0 pid=289) INFO 11-28 11:47:31 [parallel_state.py:1208] world_size=1 rank=0 local_rank=0 distributed_init_method=tcp://172.19.2.2:54751 backend=nccl

Single GPU setup. world_size=1 means a single worker process; backend=nccl selects NCCL (the NVIDIA Collective Communications Library), which handles fast GPU-to-GPU communication in multi-GPU setups.

[W1128 11:47:31.402851082 socket.cpp:209] [c10d] The hostname of the client socket cannot be retrieved. err=-3

[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0

Benign. With one GPU, there are no peers to connect to.

(EngineCore_DP0 pid=289) INFO 11-28 11:47:31 [parallel_state.py:1394] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0

Parallelism types (all ranks are 0 because this is a single GPU; a toy tensor-parallel sketch follows the list). See How to Parallelize a Transformer from the JAX Scaling Book:

  • DP (Data Parallel): activations sharded along batch dimension, parameters replicated on each device, AllReduce gradients during backward pass
  • PP (Pipeline Parallel): layers distributed across devices, activations microbatched through pipeline stages
  • TP (Tensor Parallel): activations sharded along model dimension, parameters sharded along feed-forward dimension, AllGather/ReduceScatter between blocks
  • EP (Expert Parallel): for MoE models, distribute different experts across GPUs, tokens routed to GPUs holding selected experts
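
A toy tensor-parallel matmul in NumPy, sharding the feed-forward weight across two simulated devices (purely illustrative; real TP shards across GPUs and uses collectives such as AllGather):

```python
import numpy as np

x = np.random.randn(4, 8)            # activations: [batch, hidden]
W = np.random.randn(8, 16)           # feed-forward weight: [hidden, ffn_dim]

# Tensor parallelism: shard W along the feed-forward dimension across 2 "devices".
W0, W1 = np.split(W, 2, axis=1)      # each device holds half the columns

# Each device computes its partial output locally...
y0 = x @ W0                          # [batch, ffn_dim/2] on device 0
y1 = x @ W1                          # [batch, ffn_dim/2] on device 1

# ...and an AllGather reassembles the full activation.
y_tp = np.concatenate([y0, y1], axis=1)

assert np.allclose(y_tp, x @ W)      # same result as the unsharded matmul
```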

(EngineCore_DP0 pid=289) INFO 11-28 11:47:31 [gpu_model_runner.py:3259] Starting to load model /kaggle/input/gpt-oss-120b/transformers/default/1...

(EngineCore_DP0 pid=289) WARNING 11-28 11:47:32 [mxfp4.py:196] MXFP4 linear layer is not implemented - falling back to UnquantizedLinearMethod.

Some layers lack FP4 kernel implementations and run in bfloat16 instead. See vllm/model_executor/layers/quantization/mxfp4.py.

(EngineCore_DP0 pid=289) WARNING 11-28 11:47:32 [mxfp4.py:208] MXFP4 attention layer is not implemented. Skipping quantization for this layer.

Attention layers (Q, K, V projections) run unquantized. See the same mxfp4.py source.

(EngineCore_DP0 pid=289) INFO 11-28 11:47:32 [cuda.py:377] Using AttentionBackendEnum.TRITON_ATTN backend.

vLLM supports multiple attention backends (FlashAttention, FlashInfer, Triton, and others), chosen based on the GPU architecture, the model, and which libraries are installed. Here, TRITON_ATTN is selected.

(EngineCore_DP0 pid=289) INFO 11-28 11:47:32 [layer.py:342] Enabled separate cuda stream for MoE shared_experts

GPT-OSS is a Mixture-of-Experts (MoE) model with shared experts (always active, capture common knowledge) and routed experts (selected by gating). Running shared experts on a separate CUDA stream allows overlapping their computation with the routed expert dispatch/combine operations. The expected overlap pattern is: shared experts compute in parallel with token dispatch to routed experts, then routed experts compute, then results combine. See vLLM RFC #9203 and vLLM MoE implementation.
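
A minimal PyTorch sketch of this overlap pattern (illustrative, not vLLM's implementation; shared_expert, route_and_dispatch, and routed_experts are hypothetical stand-ins):

```python
import torch

shared_stream = torch.cuda.Stream()

def moe_forward(x, shared_expert, route_and_dispatch, routed_experts):
    # Launch the shared expert on its own stream so it can overlap
    # with routing/dispatch work on the default stream.
    shared_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(shared_stream):
        shared_out = shared_expert(x)

    # Meanwhile, the default stream routes tokens and runs the selected experts.
    dispatched = route_and_dispatch(x)
    routed_out = routed_experts(dispatched)

    # Wait for the shared-expert stream before combining results.
    torch.cuda.current_stream().wait_stream(shared_stream)
    return shared_out + routed_out
```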

(EngineCore_DP0 pid=289) INFO 11-28 11:47:32 [mxfp4.py:141] Using Marlin backend

Marlin is a weight-only quantization kernel from IST-DASLab. On GPUs without native FP4 tensor cores, Marlin decompresses FP4 weights on-the-fly during computation. See vLLM Marlin utils.

(EngineCore_DP0 pid=289) Loading safetensors checkpoint shards: 100% Completed 15/15 [09:05<00:00, 36.34s/it]

(EngineCore_DP0 pid=289) INFO 11-28 11:57:30 [default_loader.py:314] Loading weights took 545.24 seconds

9 minutes is slow. Kaggle storage I/O is the bottleneck. Local NVMe + tensor parallelism would be faster.

(EngineCore_DP0 pid=289) WARNING 11-28 11:57:30 [marlin_utils_fp4.py:204] Your GPU does not have native support for FP4 computation but FP4 quantization is being used. Weight-only FP4 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.

This Kaggle environment’s GPU lacks native FP4 tensor cores. Native FP4 requires Blackwell architecture (SM100+)—Hopper (SM90) only supports FP8. The Marlin kernel provides weight-only compression: weights stored in FP4 (memory savings) but computed in higher precision. See GitHub issue #30135 for more on MXFP4 backend selection.

(EngineCore_DP0 pid=289) INFO 11-28 11:57:33 [gpu_model_runner.py:3338] Model loading took 65.9651 GiB memory and 600.408480 seconds

65.97 GiB is expected. Here’s the math based on GPT-OSS-120B config.json:

Architecture: 36 layers, hidden_size=2880, 128 experts per layer, intermediate_size=2880 per expert

MoE weights (~115B params, MXFP4 quantized):

  • Each expert (SwiGLU): 3 × hidden_size × intermediate_size = 3 × 2880 × 2880 = 24.9M params
  • Per layer: 128 experts × 24.9M = 3.19B params
  • 36 layers: 36 × 3.19B = 114.8B params
  • MXFP4 (4-bit + scaling): 114.8B × 0.5 bytes × 1.03 ≈ 59.1 GB

Non-MoE weights (~2B params, BF16):

  • Attention (Q/K/V/O), embeddings, layernorms, routers
  • ~2B × 2 bytes = ~4 GB

Total: 59.1 + 4 + ~2-3 GB (CUDA context, buffers) ≈ 65-66 GB

This matches Unsloth’s documentation which recommends “at least 66GB of unified memory” for gpt-oss-120B inference. See also OpenAI’s model card for architecture details.
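
The arithmetic above, reproduced in a few lines (the 1.03 factor for MXFP4 scale overhead and the ~2.5 GB for CUDA context and buffers are the same rough estimates used above):

```python
layers, hidden, experts, intermediate = 36, 2880, 128, 2880

# MoE expert weights: SwiGLU has 3 projections of hidden × intermediate each.
moe_params = layers * experts * 3 * hidden * intermediate      # ≈ 114.7B params
moe_gb = moe_params * 0.5 * 1.03 / 1e9                         # MXFP4: 4 bits + scales

# Non-MoE weights (attention, embeddings, norms, routers) in BF16.
non_moe_gb = 2e9 * 2 / 1e9                                     # ≈ 4 GB

total_gb = moe_gb + non_moe_gb + 2.5                           # + CUDA context, buffers
print(round(moe_gb, 1), round(total_gb, 1))                    # ~59.1, ~65.6
```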

(EngineCore_DP0 pid=289) INFO 11-28 11:57:52 [backends.py:631] Using cache directory: /root/.cache/vllm/torch_compile_cache/7fcbe477d2/rank_0_0/backbone for vLLM's torch.compile

(EngineCore_DP0 pid=289) INFO 11-28 11:57:52 [backends.py:647] Dynamo bytecode transform time: 19.07 s

(EngineCore_DP0 pid=289) INFO 11-28 11:57:58 [backends.py:251] Cache the graph for dynamic shape for later use

(EngineCore_DP0 pid=289) INFO 11-28 11:58:38 [backends.py:282] Compiling a graph for dynamic shape takes 44.90 s

(EngineCore_DP0 pid=289) INFO 11-28 11:58:39 [monitor.py:34] torch.compile takes 63.97 s in total

torch.compile is PyTorch’s JIT compiler. TorchDynamo (the frontend) traces Python code and captures computation graphs. TorchInductor (the backend) then fuses multiple operations into optimized CUDA/Triton kernels—this reduces memory bandwidth usage since intermediate results stay in registers instead of being written to and read from global memory. See PyTorch’s torch.compile tutorial for details on how fusion works. vLLM adds custom fusion passes (RMSNorm+quantization, SiLU+quantization) for additional speedups. Compiled artifacts are cached in ~/.cache/vllm/torch_compile_cache for reuse across runs. See vLLM torch.compile docs.
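
A standalone example of the mechanism (not vLLM's compilation pipeline): torch.compile traces the function with Dynamo and lets Inductor fuse the elementwise and reduction ops into fewer kernels:

```python
import torch

def rmsnorm_silu(x, weight):
    # Several elementwise/reduction ops that Inductor can fuse into fewer kernels.
    norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    return torch.nn.functional.silu(norm * weight)

compiled = torch.compile(rmsnorm_silu)          # Dynamo traces, Inductor generates code

x = torch.randn(8, 2880, device="cuda", dtype=torch.bfloat16)
w = torch.ones(2880, device="cuda", dtype=torch.bfloat16)
out = compiled(x, w)   # first call compiles; later calls reuse the cached artifact
```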

(EngineCore_DP0 pid=289) INFO 11-28 11:58:41 [gpu_worker.py:359] Available KV cache memory: 8.99 GiB

(EngineCore_DP0 pid=289) INFO 11-28 11:58:41 [kv_cache_utils.py:1229] GPU KV cache size: 130,960 tokens

(EngineCore_DP0 pid=289) INFO 11-28 11:58:41 [kv_cache_utils.py:1234] Maximum concurrency for 98,304 tokens per request: 2.46x

The 2.46x multiplier is calculated in vllm/v1/core/kv_cache_utils.py. GPT-OSS uses an alternating attention pattern: odd layers use sliding window attention (sliding_window=128), even layers use full attention. This creates 2 KV cache groups.

Derivation (block_size=16; page_size=32,768 bytes per block per layer for GQA: 8 KV heads × 64 head_dim × 2 bytes × 2 for K+V × 16 tokens per block; a script reproducing this arithmetic follows the list):

  1. FullAttentionSpec memory per layer: ⌈98,304 / 16⌉ × 32,768 = 201.3 MB
  2. SlidingWindowSpec memory per layer: ⌈(128 − 1 + 8,192) / 16⌉ × 32,768 = 17.0 MB
    • Note: includes max_num_batched_tokens (8,192) for chunked prefill buffer
  3. Sum per layer: 201.3 + 17.0 = 218.4 MB
  4. Per request (18 layers per group): 18 × 218.4 MB = 3.93 GB
  5. Blocks per request: ⌈3.93 GB / (32,768 × 18)⌉ = 6,666 blocks
  6. Total blocks: 130,960 tokens × 2 groups / 16 block_size = 16,370 blocks
  7. max_concurrency: 16,370 / 6,666 = 2.46x
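
The same arithmetic as a short script (constants taken from the log and config above; this mirrors the simplified derivation, not vLLM's exact code in kv_cache_utils.py):

```python
import math

block_size = 16
page_size = 8 * 64 * 2 * 2 * block_size          # 8 KV heads × 64 head_dim × 2 bytes × (K+V) × 16 = 32,768
max_model_len, sliding_window, max_batched = 98_304, 128, 8_192

full_blocks = math.ceil(max_model_len / block_size)                       # 6,144 per request
swa_blocks = math.ceil((sliding_window - 1 + max_batched) / block_size)   # 520 per request
blocks_per_request = full_blocks + swa_blocks                             # blocks one max-length request needs

kv_cache_tokens = 130_960
total_blocks = kv_cache_tokens * 2 // block_size                          # 16,370 (2 KV cache groups)

print(round(total_blocks / blocks_per_request, 2))                        # ≈ 2.46
```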

When KV cache fills: vLLM preempts lower-priority requests via recomputation (V1 default) or swap to CPU. This causes latency spikes for preempted requests.

(EngineCore_DP0 pid=289) 2025-11-28 11:58:41,951 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...

(EngineCore_DP0 pid=289) 2025-11-28 11:58:41,973 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends

(EngineCore_DP0 pid=289) Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 83/83

83 CUDA graphs captured (one per batch size in cudagraph_capture_sizes up to 1024). PIECEWISE mode splits the computation graph at attention operations—attention stays in eager mode while everything else (MLPs, norms) goes into CUDA graphs. This is necessary because attention has variable memory access patterns that are difficult to capture in static graphs.

Why CUDA graphs matter: each kernel launch has ~20μs of CPU overhead for setup and dispatch. For small batches where GPU compute time is short, this overhead dominates total latency. CUDA graphs pre-record the entire workflow once, then replay it with near-constant launch time (~2.5μs + ~1ns per node). See NVIDIA CUDA Programming Guide.
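
A standalone PyTorch example of the capture-and-replay pattern (illustrative; vLLM adds its own batch-size sweep, warmup, and piecewise splitting on top):

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(2880, 2880), torch.nn.GELU(), torch.nn.Linear(2880, 2880)
).cuda().eval()

static_in = torch.randn(8, 2880, device="cuda")   # graphs replay on fixed shapes and buffers

# Warm up on a side stream before capture (as recommended by the PyTorch docs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        model(static_in)
torch.cuda.current_stream().wait_stream(s)

# Record the forward pass once...
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph), torch.no_grad():
    static_out = model(static_in)

# ...then replay it with near-constant launch overhead: copy new data into the
# static input buffer and replay; static_out is updated in place.
static_in.copy_(torch.randn(8, 2880, device="cuda"))
graph.replay()
```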

(EngineCore_DP0 pid=289) Capturing CUDA graphs (decode, FULL): 3/3

(EngineCore_DP0 pid=289) INFO 11-28 11:58:51 [gpu_model_runner.py:4244] Graph capturing finished in 9 secs, took 0.64 GiB

3 full graphs for decode-only batches, whose simpler access pattern allows capturing the entire forward pass. The count is 3 rather than 83 most likely because max_num_seqs=6 caps pure-decode batches at 6 sequences, so only the capture sizes 1, 2, and 4 apply. Graph capture used 0.64 GiB of memory.

(EngineCore_DP0 pid=289) INFO 11-28 11:58:51 [core.py:250] init engine (profile, create kv cache, warmup model) took 78.26 seconds

(APIServer pid=90) INFO 11-28 11:58:53 [api_server.py:1725] Supported tasks: ['generate']

(APIServer pid=90) WARNING 11-28 11:58:53 [serving_responses.py:175] For gpt-oss, we ignore --enable-auto-tool-choice and always enable tool use.

(APIServer pid=90) INFO 11-28 11:58:53 [api_server.py:2052] Starting vLLM API server 0 on http://0.0.0.0:8000

(APIServer pid=90) INFO 11-28 11:58:53 [launcher.py:38] Available routes are:

(APIServer pid=90) INFO 11-28 11:58:53 [launcher.py:46] Route: /docs, Methods: GET, HEAD

(APIServer pid=90) INFO 11-28 11:58:53 [launcher.py:46] Route: /tokenize, Methods: POST

(APIServer pid=90) INFO 11-28 11:58:53 [launcher.py:46] Route: /detokenize, Methods: POST

(APIServer pid=90) INFO 11-28 11:58:53 [launcher.py:46] Route: /v1/models, Methods: GET

(APIServer pid=90) INFO 11-28 11:58:53 [launcher.py:46] Route: /v1/chat/completions, Methods: POST

(APIServer pid=90) INFO 11-28 11:58:53 [launcher.py:46] Route: /v1/completions, Methods: POST

(APIServer pid=90) INFO 11-28 11:58:53 [launcher.py:46] Route: /metrics, Methods: GET

vLLM exposes an OpenAI-compatible API (a sample client request follows the list):

  • /v1/chat/completions - Chat completions (conversation format with messages array)
  • /v1/completions - Text completions (raw prompt format)
  • /v1/models - List available models and metadata
  • /tokenize, /detokenize - Encode text to token IDs and decode back
  • /metrics - Prometheus metrics for monitoring (vllm:e2e_request_latency_seconds, vllm:num_requests_running, vllm:kv_cache_usage_perc)
  • /docs - Swagger/OpenAPI interactive documentation
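
A sample request against this server using the official OpenAI Python client (base URL and served model name taken from the logs above; any placeholder API key works unless the server was started with --api-key):

```python
from openai import OpenAI

# vLLM serves an OpenAI-compatible API on the host/port from the launch log.
client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="vllm-model",   # the served_model_name from the non-default args
    messages=[{"role": "user", "content": "Explain KV cache in one sentence."}],
    max_tokens=128,
)
print(response.choices[0].message.content)
```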

(APIServer pid=90) INFO: Started server process [90]

(APIServer pid=90) INFO: Waiting for application startup.

(APIServer pid=90) INFO: Application startup complete.

(APIServer pid=90) INFO: 127.0.0.1:55440 - "GET /v1/models HTTP/1.1" 200 OK

(APIServer pid=90) INFO: 127.0.0.1:55440 - "POST /v1/chat/completions HTTP/1.1" 200 OK

(APIServer pid=90) INFO: 127.0.0.1:55484 - "POST /v1/chat/completions HTTP/1.1" 200 OK

(APIServer pid=90) INFO 11-28 11:59:14 [loggers.py:236] Engine 000: Avg generation throughput: 544.8 tokens/s, GPU KV cache usage: 4.3%, Prefix cache hit rate: 81.6%

(APIServer pid=90) INFO 11-28 11:59:44 [loggers.py:236] Engine 000: Avg generation throughput: 514.2 tokens/s, GPU KV cache usage: 10.4%, Prefix cache hit rate: 81.6%

(APIServer pid=90) INFO 11-28 12:00:44 [loggers.py:236] Engine 000: Avg generation throughput: 468.0 tokens/s, GPU KV cache usage: 21.5%, Prefix cache hit rate: 81.6%

(APIServer pid=90) INFO 11-28 12:01:44 [loggers.py:236] Engine 000: Avg generation throughput: 426.6 tokens/s, GPU KV cache usage: 31.6%, Prefix cache hit rate: 81.6%

(APIServer pid=90) INFO 11-28 12:02:44 [loggers.py:236] Engine 000: Avg generation throughput: 396.0 tokens/s, GPU KV cache usage: 41.0%, Prefix cache hit rate: 81.6%

(APIServer pid=90) INFO 11-28 12:03:54 [loggers.py:236] Engine 000: Avg generation throughput: 367.2 tokens/s, GPU KV cache usage: 51.1%, Prefix cache hit rate: 81.6%

(APIServer pid=90) INFO 11-28 12:04:54 [loggers.py:236] Engine 000: Avg generation throughput: 348.6 tokens/s, GPU KV cache usage: 59.3%, Prefix cache hit rate: 81.6%

Throughput drops as KV cache fills:

  • 4.3% cache → 544.8 tok/s
  • 21.5% cache → 468.0 tok/s
  • 41.0% cache → 396.0 tok/s
  • 59.3% cache → 348.6 tok/s

This demonstrates that LLM decoding is memory-bandwidth bound. During each decode step, the GPU must load the entire KV cache from HBM (high-bandwidth memory) to compute attention over all previous tokens. As sequences grow longer, the amount of data transferred per token increases linearly, but the compute per token stays constant (one output token). According to NVIDIA’s LLM optimization guide, “the speed at which the data (weights, keys, values, activations) is transferred to the GPU from memory dominates the latency, not how fast the computation actually happens.”

With 6 concurrent requests generating long outputs, each request’s growing KV cache competes for memory bandwidth. The vLLM architecture blog notes that “decode requests are memory-bandwidth-bound since we still need to load all LLM weights (and KV caches) just to compute one token.”
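
A back-of-the-envelope sketch of the per-token memory traffic from the KV cache alone (weight reads are ignored; dimensions are the GQA values from the KV cache derivation above, with GPT-OSS's 18 full-attention and 18 sliding-window layers):

```python
def kv_bytes_per_decode_step(seq_len, kv_heads=8, head_dim=64, dtype_bytes=2,
                             full_layers=18, swa_layers=18, window=128):
    per_token_per_layer = kv_heads * head_dim * 2 * dtype_bytes    # K and V
    full = full_layers * seq_len * per_token_per_layer             # full-attention layers
    swa = swa_layers * min(seq_len, window) * per_token_per_layer  # sliding-window layers
    return full + swa

for seq_len in (1_000, 10_000, 50_000, 98_304):
    gb = kv_bytes_per_decode_step(seq_len) / 1e9
    print(f"{seq_len:>7} tokens -> {gb:.2f} GB of KV cache read per decode step")
```

Per request, the KV cache read per generated token grows from well under 0.1 GB at short contexts to several GB near the 98,304-token limit, and that traffic is multiplied by the number of concurrent requests.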