Matryoshka Representation Learning
MRL trains embeddings so any prefix of the vector is independently useful. How the multi-scale loss works, which models support it, and how to use it.
Most embedding models give you one vector and you use the whole thing. If the model produces 1536 dimensions, you store 1536 floats, compute similarity over all 1536, and pay for all 1536 in your vector index. The only way to get a smaller representation is to train a different model.
Matryoshka Representation Learning changes that. It trains a single model so that the first 64 dimensions, the first 256, the first 512, and the full dimension all independently carry useful semantic information. You get one model that behaves like several, and you decide at query time how much of the vector to use.
The name comes from Russian nesting dolls. The small doll sits inside the medium doll, which sits inside the large one. Same idea here: the small embedding is a prefix of the medium one, which is a prefix of the full one.
The core problem MRL solves
Standard embedding training optimizes for one thing: the quality of the full-dimensional vector. The structure of the intermediate dimensions is ignored. If you slice off the first 64 dimensions of a standard 768-dim embedding, you usually get a much weaker representation, sometimes close to unusable for retrieval. Nothing in training pushed the model to concentrate useful information there.
This creates a hard tradeoff. You want high retrieval quality, which means large embeddings. You want low storage and fast search, which means small embeddings. The only escape was to train separate models at each target dimension, which is expensive and annoying to maintain.
MRL breaks this tradeoff by restructuring the training objective itself.
How MRL training works
You can apply a classification or contrastive loss at multiple embedding sizes simultaneously, and backpropagate all of them through a single shared backbone.
Given an input, the model produces a full-dimensional embedding z of size d (e.g., 1024). MRL evaluates that embedding at a set of nested sizes M = {8, 16, 32, 64, 128, 256, 512, 1024}. For each size m, it takes the first m dimensions z[1:m], applies a linear classification head W_m, and computes the task loss.
The total MRL loss is:
```plaintext
L_MRL = Σ_{m ∈ M} c_m · L(W_m · z[1:m], y)
```

Where:

- z[1:m] is the prefix slice of the full embedding
- W_m is a linear head specific to dimension m (learned during training, discarded at inference)
- L is your task loss: cross-entropy for classification, contrastive loss for retrieval
- c_m is a per-scale weight (usually uniform: c_m = 1 for all m)
Every forward pass contributes gradients from every scale at once. The backbone learns to front-load information: the first few dimensions get pushed to capture the most discriminative signal possible, because they are the only dimensions evaluated at the smallest scales. Larger prefixes pick up progressively more detail.
The linear heads are auxiliary. They exist to create a gradient signal at each scale. You throw them away after training.
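The sum above can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the paper's implementation: the `heads` dictionary, the toy 4-dim embedding, and the helper names `cross_entropy` and `mrl_loss` are all hypothetical, and a real setup would use a tensor library with autograd.

```python
import math


def cross_entropy(logits, label):
    """Cross-entropy for one example: -log softmax(logits)[label]."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_sum - logits[label]


def mrl_loss(z, label, heads, weights=None):
    """Weighted sum of task losses over nested prefixes of z.

    heads maps each prefix size m to a hypothetical head W_m,
    stored as a list of rows (one row of length m per class).
    """
    total = 0.0
    for m, w_m in heads.items():
        prefix = z[:m]  # z[1:m] in the article's 1-based notation
        logits = [sum(w * x for w, x in zip(row, prefix)) for row in w_m]
        c_m = 1.0 if weights is None else weights[m]
        total += c_m * cross_entropy(logits, label)
    return total


# Toy setup: a 4-dim embedding, 2 classes, nested sizes M = {2, 4}
z = [0.9, -0.2, 0.4, 0.1]
heads = {
    2: [[1.0, 0.0], [0.0, 1.0]],
    4: [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]],
}
loss = mrl_loss(z, label=0, heads=heads)
print(f"L_MRL = {loss:.4f}")
```

Each prefix contributes its own term, so the first two dimensions are scored twice here: once on their own and once as part of the full vector. That repeated pressure on the leading dimensions is what the front-loading argument relies on.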
The loss in more detail
For retrieval models, the task loss L is typically a contrastive loss like MultipleNegativesRankingLoss or InfoNCE:
```plaintext
L_contrastive = -log [ exp(sim(q, d+) / τ) / Σ_j exp(sim(q, d_j) / τ) ]
```

Where q is the query embedding prefix, d+ is the positive document embedding prefix, d_j ranges over all documents in the batch (the positive and the in-batch negatives), and τ is a temperature.
MRL wraps this: for each scale m, compute the contrastive loss using only the first m dimensions of both query and document embeddings. Sum them up. The backbone sees gradients from all scales simultaneously.
The result is that every prefix in the trained set, not just the full vector, is optimized directly to be a strong representation at its size. The training objective makes it genuinely costly to bury signal deep in the vector, where the small-scale losses can't see it.
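As a rough sketch of the per-scale contrastive loss described above, the following plain-Python example computes an InfoNCE-style loss on prefixes of a query and a small batch of documents, then sums over scales. The function names, the toy vectors, and the temperature of 0.1 are assumptions made for illustration; cosine similarity is used for sim.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def info_nce(query, docs, positive_idx, tau=0.1):
    """-log softmax over similarities, with docs[positive_idx] as the positive."""
    logits = [cosine(query, d) / tau for d in docs]
    peak = max(logits)  # subtract the max for numerical stability
    log_sum = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_sum - logits[positive_idx]


def matryoshka_contrastive(query, docs, positive_idx, dims, tau=0.1):
    """Sum the contrastive loss over prefixes of the query and every document."""
    per_scale = {
        m: info_nce(query[:m], [d[:m] for d in docs], positive_idx, tau)
        for m in dims
    }
    return sum(per_scale.values()), per_scale


# Toy batch: one query, three documents, document 0 is the positive
query = [1.0, 0.0, 0.5, 0.5]
docs = [
    [0.9, 0.1, 0.4, 0.6],
    [-0.2, 1.0, 0.5, 0.5],
    [0.0, -1.0, 0.1, 0.9],
]
total, per_scale = matryoshka_contrastive(query, docs, positive_idx=0, dims=[2, 4])
for m, value in per_scale.items():
    print(f"dim={m} loss={value:.6f}")
print(f"total={total:.6f}")
```

Note that no linear heads appear here: in the contrastive setting the prefixes are compared to each other directly, so the only thing that changes per scale is how many dimensions enter the similarity.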
Training MRL with sentence-transformers
The sentence-transformers library has native MRL support through MatryoshkaLoss.
```python
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from datasets import load_dataset

# Start from any pretrained model
model = SentenceTransformer("BAAI/bge-base-en-v1.5")

# Base loss applied at each matryoshka scale
base_loss = MultipleNegativesRankingLoss(model)

# Wrap with MatryoshkaLoss to apply at nested sizes
loss = MatryoshkaLoss(
    model=model,
    loss=base_loss,
    matryoshka_dims=[768, 512, 256, 128, 64],  # descending order
    matryoshka_weights=[1, 1, 1, 1, 1],  # uniform weighting
)

# Dataset: pairs of (query, positive_document)
# Depending on the dataset, load_dataset may also need a subset name; check the dataset card.
dataset = load_dataset("sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1", split="train")

args = SentenceTransformerTrainingArguments(
    output_dir="mrl-bge-base",
    num_train_epochs=1,
    per_device_train_batch_size=64,
    learning_rate=2e-5,
    warmup_ratio=0.1,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    loss=loss,
)
trainer.train()
model.save_pretrained("mrl-bge-base-finetuned")
```

The MatryoshkaLoss wrapper handles slicing z[1:m] and summing the weighted losses automatically. You write the base loss once and MRL does the rest.
Using MRL embeddings at inference
After training, you embed your corpus once at full dimension and store the results. At query time, you can truncate to any size in your trained set.
```python
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

documents = [
    "Matryoshka Representation Learning trains nested embeddings.",
    "FAISS is a library for efficient similarity search.",
    "Vector databases store embeddings and support ANN queries.",
]
query = "How do vector databases work?"

# Note: this model's card recommends task prefixes on the inputs,
# such as "search_document: " and "search_query: ".

# Embed at full dimension (768 for this model)
doc_embeddings = model.encode(documents, convert_to_numpy=True)
query_embedding = model.encode(query, convert_to_numpy=True)

# Use only first 256 dims, still works well
dim = 256
doc_embs_truncated = doc_embeddings[:, :dim]
query_emb_truncated = query_embedding[:dim]

# L2 normalize before cosine similarity
doc_embs_norm = doc_embs_truncated / np.linalg.norm(doc_embs_truncated, axis=1, keepdims=True)
query_emb_norm = query_emb_truncated / np.linalg.norm(query_emb_truncated)

scores = doc_embs_norm @ query_emb_norm
ranked = np.argsort(scores)[::-1]
for i in ranked:
    print(f"{scores[i]:.4f} {documents[i]}")
```

You encoded once. The truncation happens in NumPy. No re-encoding, no second model.
MixedBread and Jina models
MixedBread’s mxbai-embed-large-v1 and Jina’s jina-embeddings-v3 both ship MRL-trained. With sentence-transformers, truncation is a single constructor argument.
```python
from sentence_transformers import SentenceTransformer

# MixedBread: 1024-dim model, truncate to 512
model = SentenceTransformer(
    "mixedbread-ai/mxbai-embed-large-v1",
    truncate_dim=512,
)
texts = ["What is retrieval augmented generation?"]
embeddings = model.encode(texts, convert_to_numpy=True)
print(embeddings.shape)  # (1, 512)
```

Jina's v3 model supports even smaller sizes and works across multiple languages:
```python
from sentence_transformers import SentenceTransformer

# Jina v3: 1024-dim, supports 32/64/128/256/512/1024
model = SentenceTransformer(
    "jinaai/jina-embeddings-v3",
    trust_remote_code=True,
    truncate_dim=256,
)
texts = ["What is retrieval augmented generation?"]
embeddings = model.encode(texts, task="retrieval.query", convert_to_numpy=True)
print(embeddings.shape)  # (1, 256)
```

The truncate_dim parameter handles the slice inside the model, so you get the same prefix as slicing manually. It does not renormalize the truncated vectors on its own: pass normalize_embeddings=True to encode(), or score with cosine similarity, if you need unit-length vectors.
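To see why renormalization matters after truncation: the prefix of a unit-length vector is generally shorter than unit length, so a raw dot product between prefixes is no longer a cosine similarity. The sketch below uses a hypothetical helper, `truncate_and_normalize`, and a toy 4-dim vector to show the effect.

```python
import math


def l2_norm(vec):
    """Euclidean length of a vector."""
    return math.sqrt(sum(x * x for x in vec))


def truncate_and_normalize(vec, dim):
    """Keep the first dim values and rescale them to unit length."""
    prefix = vec[:dim]
    norm = l2_norm(prefix)
    if norm == 0.0:
        raise ValueError("cannot normalize an all-zero prefix")
    return [x / norm for x in prefix]


full = [0.5, 0.5, 0.5, 0.5]  # unit length at full dimension
raw_prefix = full[:2]  # sliced only, no longer unit length
fixed_prefix = truncate_and_normalize(full, 2)

print(f"full norm:         {l2_norm(full):.4f}")
print(f"raw prefix norm:   {l2_norm(raw_prefix):.4f}")
print(f"fixed prefix norm: {l2_norm(fixed_prefix):.4f}")
```

If your index uses inner product as the metric, store the renormalized prefixes. If it uses cosine distance, the index normalizes for you and the raw slice is enough.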
Quality at each scale
MRL embeddings lose quality as you reduce dimensions, but the dropoff is gradual and often surprisingly small at moderate reductions.
```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator

model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

# Load your eval set: queries, corpus, relevant docs mapping
queries = {"q1": "What causes transformer attention to be expensive?"}
corpus = {
    "d1": "Attention is O(n^2) in sequence length due to all-pairs token comparison.",
    "d2": "Transformers use multi-head self-attention to process sequences in parallel.",
    "d3": "The feed-forward sublayer in a transformer is applied independently per token.",
}
relevant = {"q1": {"d1"}}

dims_to_test = [64, 128, 256, 512, 768]
for dim in dims_to_test:
    evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant,
        truncate_dim=dim,  # sentence-transformers handles the slice
        name=f"dim-{dim}",
    )
    results = evaluator(model)
    ndcg = results[f"dim-{dim}_cosine_ndcg@10"]
    print(f"dim={dim:4d} ndcg@10={ndcg:.4f}")
```

As a rough guide, cutting from 768 to 256 costs roughly 2-5% on NDCG@10 for most datasets. Cutting to 64 starts hurting, but it is still meaningfully better than a random projection at the same size.
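For readers unfamiliar with the metric, here is a minimal sketch of NDCG@k with binary relevance for a single query. It follows the common textbook definition (gain of 1 for a relevant document, discounted by log2 of the rank plus one); the function name `ndcg_at_k` is a hypothetical helper, and an evaluation library may differ in details such as graded relevance or averaging across queries.

```python
import math


def ndcg_at_k(ranked_ids, relevant_ids, k=10):
    """Binary-relevance NDCG@k for one query."""
    # Discounted gain of the relevant documents in the top k of the ranking
    dcg = sum(
        1.0 / math.log2(rank + 2)
        for rank, doc_id in enumerate(ranked_ids[:k])
        if doc_id in relevant_ids
    )
    # Best achievable DCG: all relevant documents ranked first
    ideal_hits = min(len(relevant_ids), k)
    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(ideal_hits))
    return dcg / idcg if idcg > 0 else 0.0


relevant = {"d1"}
print(ndcg_at_k(["d1", "d2", "d3"], relevant))  # relevant doc ranked first
print(ndcg_at_k(["d2", "d1", "d3"], relevant))  # relevant doc ranked second
print(ndcg_at_k(["d2", "d3"], relevant))        # relevant doc missing
```

Because the discount is logarithmic, dropping the only relevant document from rank 1 to rank 2 already costs more than a third of the score, which is why small NDCG differences between dimensions can still reflect visible ranking changes.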
Models that ship MRL by default
Several widely used models now come MRL-optimized out of the box. No retraining needed.
| Model | Provider | Full dim | MRL dims available |
|---|---|---|---|
| mxbai-embed-large-v1 | MixedBread | 1024 | 64, 128, 256, 512, 1024 |
| jina-embeddings-v3 | Jina AI | 1024 | 32, 64, 128, 256, 512, 1024 |
| nomic-embed-text-v1.5 | Nomic AI | 768 | 64, 128, 256, 512, 768 |
| snowflake-arctic-embed-m-v1.5 | Snowflake | 768 | 256, 768 |
| bge-m3 | BAAI | 1024 | Partial (FlagEmbedding) |
For any of these, you get the MRL property for free. Truncate the output to whatever your storage budget allows.
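To put the storage budget in concrete terms, the arithmetic below estimates the raw size of the stored vectors at several prefix sizes. It assumes float32 values (4 bytes each) and a hypothetical corpus of 10 million vectors, and it ignores index overhead such as graph links or metadata, which varies by vector store.

```python
def vector_storage_bytes(num_vectors, dim, bytes_per_value=4):
    """Raw bytes needed to store num_vectors vectors of the given dimension."""
    return num_vectors * dim * bytes_per_value


NUM_VECTORS = 10_000_000  # hypothetical corpus size
for dim in (768, 512, 256, 128, 64):
    gib = vector_storage_bytes(NUM_VECTORS, dim) / 2**30
    print(f"dim={dim:4d}  {gib:6.2f} GiB")
```

The size scales linearly with the dimension, so going from 768 to 64 dimensions shrinks the raw vector data by a factor of 12.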
Adaptive retrieval
MRL also enables two-stage retrieval with a single model. First pass uses small embeddings for fast ANN search over a large corpus. Second pass re-scores the top candidates with the full embedding.
```python
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

# Simulate a corpus
corpus = [f"Document {i} about topic {i % 10}" for i in range(10_000)]
corpus_embs = model.encode(corpus, convert_to_numpy=True, show_progress_bar=True)

SMALL_DIM = 64  # prefix size used for the first-stage index


def build_index(embs: np.ndarray) -> faiss.Index:
    """Build an exact inner-product index over L2-normalized vectors."""
    embs_norm = embs / np.linalg.norm(embs, axis=1, keepdims=True)
    index = faiss.IndexFlatIP(embs_norm.shape[1])
    index.add(embs_norm.astype(np.float32))
    return index


# Only the small prefix is indexed; the full vectors stay in corpus_embs for re-scoring
small_index = build_index(corpus_embs[:, :SMALL_DIM])

query = "information retrieval with dense embeddings"
query_emb = model.encode(query, convert_to_numpy=True)

# Stage 1: fast search with small embeddings
query_small = query_emb[:SMALL_DIM]
# Divide into a new array: an in-place /= on the slice would also modify query_emb
query_small = query_small / np.linalg.norm(query_small)
_, candidates = small_index.search(query_small.reshape(1, -1).astype(np.float32), k=100)
candidates = candidates[0]

# Stage 2: re-rank candidates with full embeddings
candidate_embs = corpus_embs[candidates]
candidate_embs_norm = candidate_embs / np.linalg.norm(candidate_embs, axis=1, keepdims=True)
query_full = query_emb / np.linalg.norm(query_emb)
scores = candidate_embs_norm @ query_full
top5 = candidates[np.argsort(scores)[::-1][:5]]
for idx in top5:
    print(corpus[idx])
```

Stage 1 operates on 64-dim vectors, so the FAISS index is 12x smaller than a full 768-dim index and search is correspondingly faster. Stage 2 re-scores 100 candidates with 768-dim vectors, which is cheap. The accuracy lands close to full-dimension search at a fraction of the index cost.
This is the pattern Vespa, Weaviate, and Qdrant have started building natively: store the full vector, index a compressed version, use the full vector for final scoring.
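How close the two-stage result lands to full-dimension search is worth measuring on your own data. The sketch below shows one way to do that, comparing the two-stage top-k against an exhaustive full-dimension top-k. Everything in it is an assumption made for illustration: the vectors are synthetic Gaussians whose variance decays along the dimensions (a crude stand-in for front-loaded MRL embeddings), scoring uses a raw dot product, and the helper names are hypothetical.

```python
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def top_k(query, vectors, k, dim=None, candidates=None):
    """Ids of the k best dot-product scores, optionally on a prefix or a subset."""
    ids = range(len(vectors)) if candidates is None else candidates
    d = len(query) if dim is None else dim
    ranked = sorted(ids, key=lambda i: dot(query[:d], vectors[i][:d]), reverse=True)
    return ranked[:k]


def two_stage(query, vectors, small_dim, shortlist, k):
    """Shortlist with the prefix, then re-rank the shortlist with full vectors."""
    candidates = top_k(query, vectors, shortlist, dim=small_dim)
    return top_k(query, vectors, k, candidates=candidates)


rng = random.Random(0)
FULL_DIM = 32
# Synthetic data: earlier dimensions get larger variance (assumption, not real embeddings)
scales = [1.0 / (1 + j) for j in range(FULL_DIM)]
vectors = [[rng.gauss(0, s) for s in scales] for _ in range(500)]
query = [rng.gauss(0, s) for s in scales]

exact = top_k(query, vectors, 5)
approx = two_stage(query, vectors, small_dim=8, shortlist=50, k=5)
recall = len(set(exact) & set(approx)) / len(exact)
print(f"recall@5 against full-dimension search: {recall:.2f}")
```

The two knobs are the prefix size and the shortlist length. A longer shortlist recovers more of the exact results at the cost of more full-dimension scoring, so it is usually the cheaper one to raise first.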
Where MRL actually helps
MRL is not useful in every situation. Storage-constrained vector indexes are the clearest case. If you are indexing millions of documents and cannot afford a 1536-dim index, MRL lets you use 256 or 512 without swapping models. First-stage retrieval at scale is another: smaller vectors mean faster dot products in ANN search, and the speedup compounds as corpus size grows. It is also handy when sweeping dimension sizes for ablations, since you can test 64, 128, and 256 from a single model without retraining anything.
It does not help when your performance bottleneck is somewhere else entirely. Switching from 768 to 256 dimensions saves storage but will not fix a bad chunking scheme or a reranker that is not calibrated for your domain.
What MRL does not change
The backbone architecture is unchanged. MRL is a training strategy, not a model architecture. You can apply it on top of any transformer encoder. The only structural addition is the per-scale linear heads used in the classification setup, and those are discarded after training. The contrastive setup compares prefixes directly and adds no heads at all.
Inference code does not change either. You call model.encode() exactly as before. The only difference is you can optionally slice the output.
There is one real cost worth knowing about. Optimizing for the 64-dim prefix pushes the model to pack the most discriminative signal into the first 64 dimensions, which can in principle cost some full-dimension performance compared to a model trained only for 768 dimensions. In practice the difference is small: the full-dimension loss is itself one of the summed terms, so the full vector is still optimized directly. The original paper reports MRL representations that are about as accurate as independently trained fixed-size baselines, so the tradeoff is usually worth it.
A 4x reduction in embedding dimension translates directly to 4x less RAM in your vector index and 4x lower storage costs, all from the same model. That is a real savings, not a paper one. The fact that MixedBread, Jina, Nomic, and Snowflake all shipped MRL in their flagship models within roughly the same two-year window suggests the field decided this was worth the training overhead. If you are picking a new embedding model and MRL is not mentioned in the model card, it is worth finding out why.