
Learning Latent Embeddings


Contrastive Learning

Contrastive Learning is the most commonly used technique to train large embedding models with transformers.

The goal of contrastive learning is to learn dense vector representations of the input modality (e.g. images, text, video) in which similar items lie close together and dissimilar items lie far apart.

Siamese Networks

First introduced in Bromley et al. (1993) as a way to perform signature verification, the Siamese Network is a neural network architecture that uses two towers, or encoders, with shared weights to learn representations that separate similar items from dissimilar ones.

Siamese Network

It does this by minimizing the objective ||f(x^i) - f(x^j)||^2 when the input images x^i and x^j match and maximizing it when they do not.
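As a rough illustration, here is a minimal PyTorch sketch of the two-tower setup; the MLP encoder, layer sizes, and input dimensions are illustrative assumptions rather than the architecture from the original paper.

```python
import torch
import torch.nn as nn

class SiameseNetwork(nn.Module):
    """Two 'towers' that share the same weights, producing one embedding per input."""
    def __init__(self, in_dim=784, embed_dim=128):
        super().__init__()
        # A single encoder is reused for both inputs, so the towers share weights.
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU(),
            nn.Linear(256, embed_dim),
        )

    def forward(self, x_i, x_j):
        # Encode both inputs with the shared encoder.
        f_i, f_j = self.encoder(x_i), self.encoder(x_j)
        # Squared Euclidean distance ||f(x_i) - f(x_j)||^2 between the embeddings.
        return ((f_i - f_j) ** 2).sum(dim=-1)

# Usage: the distance should become small for matching pairs and large otherwise.
net = SiameseNetwork()
x_i, x_j = torch.randn(4, 784), torch.randn(4, 784)
dist = net(x_i, x_j)  # shape (4,)
```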

Contrastive Loss

We can formulate a loss objective where we define the similarity as the dot product of the L2-normalized feature representations.

sim(x, y) = \frac{f(x)}{||f(x)||_2} \cdot \frac{f(y)}{||f(y)||_2}

Here x is our query, x^+ is its positive, and the x_j^- are our in-batch negatives.

Our loss is defined as:

l(x, x^+) = - \log \frac{\exp(sim(x, x^+) / \tau)}{\exp(sim(x, x^+)/\tau) + \sum^N_{j=1} \exp(sim(x, x_j^-)/ \tau)}
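As a sketch, this loss can be computed directly from precomputed embeddings in PyTorch; the tensor shapes and temperature value below are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def info_nce_loss(query, positive, negatives, tau=0.07):
    """Contrastive loss for one query x, its positive x+, and N in-batch negatives x_j^-.

    query:     (D,)   embedding of x
    positive:  (D,)   embedding of x+
    negatives: (N, D) embeddings of the negatives
    """
    # L2-normalize so the dot product matches sim(x, y) above.
    q = F.normalize(query, dim=-1)
    pos = F.normalize(positive, dim=-1)
    negs = F.normalize(negatives, dim=-1)

    # Similarities scaled by the temperature tau.
    pos_logit = (q @ pos) / tau       # scalar
    neg_logits = (negs @ q) / tau     # (N,)
    logits = torch.cat([pos_logit.view(1), neg_logits])

    # Negative log-softmax of the positive entry reproduces l(x, x+) above.
    return -F.log_softmax(logits, dim=0)[0]
```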

One of the first papers to explore this idea of contrastive learning for image embeddings was Momentum Contrast, or MoCo (He et al., 2019).

The key takeaway from MoCo is its momentum encoder: rather than training a second encoder for the negatives, the key encoder is updated as a slow-moving average of the query encoder, which removes the computational cost of maintaining a separately trained encoder while keeping the negative encodings consistent over time.

w_k \leftarrow m w_k + (1 - m) w_q

Momentum Encoder
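A minimal sketch of this momentum update in PyTorch, assuming a stand-in linear encoder and a typical momentum coefficient; MoCo applies the same rule to full convolutional encoders.

```python
import copy
import torch

# encoder_q is the query encoder being trained; the key encoder starts as a copy of it.
encoder_q = torch.nn.Linear(128, 64)   # stand-in for a real encoder
encoder_k = copy.deepcopy(encoder_q)
for p in encoder_k.parameters():
    p.requires_grad = False            # the key encoder is never updated by backprop

m = 0.999  # momentum coefficient

@torch.no_grad()
def momentum_update(encoder_q, encoder_k, m):
    # w_k <- m * w_k + (1 - m) * w_q : the key encoder slowly tracks the query encoder,
    # so the encodings of the queued negatives stay consistent over time.
    for w_q, w_k in zip(encoder_q.parameters(), encoder_k.parameters()):
        w_k.data = m * w_k.data + (1.0 - m) * w_q.data
```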

Triplet Loss

Learning triplets

Triplet loss can be viewed as an improved form of contrastive loss. Given an anchor A, a positive sample P, and a negative sample N, triplet loss seeks to minimize the objective

l(A, P, N) = \max(||f(A) - f(P)||^2 - ||f(A) - f(N)||^2 + \alpha, 0)

where each norm term represents the distance between the encoding learned for the anchor A and the encodings of the positive and negative samples.

We want to make sure that ||f(A) - f(P)||^2, the distance between the anchor and the positive, is less than ||f(A) - f(N)||^2, the distance between the anchor and the negative. Additionally, we often add a margin threshold α to ensure that the model learns that positives and anchors are much more similar than negatives and anchors.
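As a sketch, the same objective in PyTorch over a batch of precomputed embeddings; the batch shapes and margin value are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def triplet_loss(f_a, f_p, f_n, alpha=0.2):
    """l(A, P, N) = max(||f(A) - f(P)||^2 - ||f(A) - f(N)||^2 + alpha, 0).

    f_a, f_p, f_n: (B, D) batches of anchor, positive, and negative embeddings.
    """
    d_pos = ((f_a - f_p) ** 2).sum(dim=-1)   # anchor-positive distance
    d_neg = ((f_a - f_n) ** 2).sum(dim=-1)   # anchor-negative distance
    # The margin alpha forces positives to sit closer to the anchor than negatives by a gap.
    return F.relu(d_pos - d_neg + alpha).mean()
```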

Relationship with Language Modeling

Computing a vector similarity between latent vectors to extract meaning from the data is reminiscent of the core principle behind self-attention in language models such as GPT (Generative Pre-trained Transformer).

As the diagram below shows, a transformer learns through self-attention how similar the different words, or tokens, in the input sequence are to one another.

Transformer as an Inquiry System

When we generate new words, the model selects the word with the highest probability in relation to the queries and keys, computed by a weighted similarity or attention function softmax(\frac{QK^T}{\sqrt{d_k}})V.

More specifically, this computation is represented by a lookup table defined by learnable parameters: our word embedding x_i and our projection matrices Q, K, and V.

q_i = Qx_i \text{ (queries)} \quad k_i = Kx_i \text{ (keys)} \quad v_i = Vx_i \text{ (values)}

Lookup Table for Queries
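A minimal sketch of this lookup in PyTorch, using randomly initialized projection matrices and illustrative sizes rather than trained weights.

```python
import torch
import torch.nn.functional as F

def attention_lookup(x, W_q, W_k, W_v):
    """Project token embeddings x (T, d_model) into queries, keys, and values,
    then weight the values by the softmax similarity between queries and keys."""
    q = x @ W_q   # q_i = Q x_i
    k = x @ W_k   # k_i = K x_i
    v = x @ W_v   # v_i = V x_i
    d_k = k.shape[-1]
    # softmax(QK^T / sqrt(d_k)) V
    weights = F.softmax(q @ k.transpose(-2, -1) / d_k ** 0.5, dim=-1)
    return weights @ v

# Usage with illustrative sizes: 10 tokens, model width 64, head width 16.
x = torch.randn(10, 64)
W_q, W_k, W_v = (torch.randn(64, 16) for _ in range(3))
out = attention_lookup(x, W_q, W_k, W_v)   # (10, 16)
```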

High-level overview

At a high level, a transformer learns an implicit probability transition matrix by predicting the next token from its context. This can be formalized with multi-headed attention like so:

p(x_t | x_{t-1}, \ldots, x_1) = softmax(FFN_n(\max(0, Norm(Concat(h_1, h_2, \ldots, h_k) W^O))))

where an attention head h_i computes the function

h_i = softmax(\frac{QK^T}{\sqrt{d_k}})V

with a multi-layer feed forward network defined as:

FFN_i(x) = \max(0, W_{i-1}(FFN_{i-1}(x) + b))W_i + b
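Putting these pieces together, a rough PyTorch sketch of the pipeline above might look like the following; the layer sizes, number of heads, vocabulary size, and the use of nn.Linear (which adds a bias term) are illustrative assumptions rather than any particular model's configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerHead(nn.Module):
    """Concatenate attention heads, project with W^O, normalize, apply max(0, .),
    run a feed-forward stack, and take a softmax over the vocabulary."""
    def __init__(self, d_model=64, n_heads=4, d_head=16, vocab_size=1000):
        super().__init__()
        self.W_o = nn.Linear(n_heads * d_head, d_model)   # W^O
        self.norm = nn.LayerNorm(d_model)                 # Norm(...)
        self.ffn = nn.Sequential(                         # FFN_n(...), last layer maps to the vocabulary
            nn.Linear(d_model, 4 * d_model), nn.ReLU(),
            nn.Linear(4 * d_model, vocab_size),
        )

    def forward(self, heads):
        # heads: list of k tensors of shape (T, d_head), one per attention head h_i
        h = torch.cat(heads, dim=-1)                      # Concat(h_1, ..., h_k)
        h = F.relu(self.norm(self.W_o(h)))                # max(0, Norm(... W^O))
        return F.softmax(self.ffn(h), dim=-1)             # p(x_t | x_{t-1}, ..., x_1)

# Usage: 4 heads over 10 tokens; the last row is the next-token distribution.
heads = [torch.randn(10, 16) for _ in range(4)]
probs = TransformerHead()(heads)   # (10, 1000)
```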

Dimensionality Reduction

Understanding the nature of the high-dimensional representations your model is learning can seem overly abstract, especially when working with vectors of dimension 1,000 or greater.

Fortunately, there exist dimensionality reduction techniques to visualize and understand these embeddings. One commonly used technique for dimensionality reduction of embeddings is UMAP (Uniform Manifold Approximation and Projection).

The core idea of UMAP is to model the global topological structure of the data and find a low-dimensional projection that preserves it. More on this here.
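As a sketch, projecting a set of embeddings down to three dimensions with the umap-learn package might look like this; the embedding array and parameter values are illustrative placeholders.

```python
import numpy as np
import umap

# Stand-in for real embeddings: 500 content vectors of dimension 768.
embeddings = np.random.rand(500, 768)

# Reduce to 3D; n_neighbors and min_dist trade off local vs. global structure.
reducer = umap.UMAP(n_components=3, n_neighbors=15, min_dist=0.1)
coords_3d = reducer.fit_transform(embeddings)  # (500, 3) points ready to plot
```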

Below is an example of how UMAP can represent text embeddings of different internet content you might see in your "For You Page" in 3D.

Interactive graph of content embeddings with UMAP

Conclusion

Thanks for reading up to this point!

Hopefully, you learned a little about representation learning and gained some intuition on how DL models learn semantic relationships through encoding latents.

Acknowledgements

Thank you to Andrew Zhang for feedback on earlier drafts!
