Word2Vec with TensorFlow

Word2Vec with Skip-Gram and TensorFlow

This is a tutorial and a basic example for getting started with word2vec model by Mikolov et al. It is used for learning vector representations of words, called "Words Embeddings". For more information about Embeddings, read my previous post.

The word2vec model can be trained with two different word representations:

  • Continuous Bag-of-Words (CBOW): predicts target words (e.g. 'mat') from source context words ('the cat sits on the')
  • Skip-Gram: predicts source context-words from the target words

Skip-Gram tends to do better and this tutorial will implement a word2vec with skip-grams.

The goal of the model is to train it's embeddings layer in a way that similar by meaning words are close to each other in their N-dimensional vector representation. The model has two layers: the embeddings layer and a linear layer. Because of the last layer is linear, the distance between embedding vectors for words is linearly related to the distance in the meaning of those words. In other words, we are able to do such mathematical operations with the vectors: [king] - [man] + [woman] ~= [queen]

In [1]:
%env CUDA_VISIBLE_DEVICES=0
import time
import numpy as np
import pandas as pd
import tensorflow as tf
import sklearn
import nltk
env: CUDA_VISIBLE_DEVICES=0

Dataset

To train a word2vec model, we need a large text corpus. This example uses the text from the "20 newsgroups dataset". The dataset contains 11314 messages form a message board with corresponding labels for its topics. We just merge all messages together and ignore the labels. In practice, it's better to use a larger corpus and to have a domain-specific text. Lowering the case of the text is optional and recommended when working with a small corpus.

In [2]:
from sklearn.datasets import fetch_20newsgroups
data = fetch_20newsgroups()

text = ' '.join(data.data).lower()
text[100:350]
Out[2]:
'umd.edu\norganization: university of maryland, college park\nlines: 15\n\n i was wondering if anyone out there could enlighten me on this car i saw\nthe other day. it was a 2-door sports car, looked to be from the late 60s/\nearly 70s. it was called a bric'

Sentence Tokenize

The skip grams will work better if they are created from sentenced text. nltk.sent_tokenize will break a string to a list of sentences.

In [3]:
sentences_text = nltk.sent_tokenize(text)
len(sentences_text)
Out[3]:
173268

Word Tokenize

Next, break all sentences to tokens (words) with nltk.word_tokenize.

In [4]:
sentences = [nltk.word_tokenize(s) for s in sentences_text]
print(sentences[10])
['please', 'send', 'a', 'brief', 'message', 'detailing', 'your', 'experiences', 'with', 'the', 'procedure', '.']

Vocabulary (unique words)

In this example, we filter words who are used less than 5 times in the text, stop words and punctuations.

In [5]:
from collections import  Counter
from string import punctuation
from nltk.corpus import stopwords

min_count = 5
puncs = set(punctuation)
stops = set(stopwords.words('english'))

flat_words = []
for sentence in sentences:
    flat_words += sentence
    
counts = Counter(list(flat_words))
counts = pd.DataFrame(counts.most_common())
counts.columns = ['word', 'count']

counts = counts[counts['count'] >= min_count]
counts = counts[~counts['word'].isin(puncs)]
counts = counts[~counts['word'].isin(stops)]


vocab = pd.Series(range(len(counts)), index=counts['word']).sort_index()

print('The vocabulary has:', len(vocab), 'words')
The vocabulary has: 34016 words

Filter tokens not in vocabulary

Some words were excluded from the vocabulary because they are very rare or too common to present value. We have to remove them from our sentences.

In [6]:
filtered_sentences = []

for sentence in sentences:
    sentence = [word for word in sentence if word in vocab.index]
    if len(sentence):
        filtered_sentences.append(sentence)
sentences = filtered_sentences

Transform the words to integer indexes

In [7]:
for i, sentence in enumerate(sentences):
    sentences[i] = [vocab.loc[word] for word in sentence]

Create Skip-Gram dataset

In [8]:
from nltk.util import skipgrams

window_size = 10

data = []
for sentance in sentences:
    data += skipgrams(sentance, 2, window_size)

data = pd.DataFrame(data, columns=['x', 'y'])
data.head()
Out[8]:
x y
0 5816 4
1 5816 122
2 5816 6
3 5816 159
4 4 122

Train and Validation Split

In [9]:
validation_size = 5000

data_valid = data.iloc[-validation_size:]
data_train = data.iloc[:-validation_size]
print('Train size:', len(data_train), 'Validation size:', len(data_valid))
Train size: 14098409 Validation size: 5000

Model Hyperparameters

In [10]:
learning_rate = .01
embed_size = 300
batch_size = 64
steps = 1000000

Model Inputs

In [11]:
inputs = tf.placeholder(tf.int32, [None])
targets = tf.placeholder(tf.int32, [None])

Embeddings Layer

This is the embeddings layer. Its a len(vocab) by embed_size matrix, initialized with random uniform distribution. The optimizer will change the similarity between it's rows to be higher on similar words.

In [12]:
embeddings = tf.Variable(tf.random_uniform((len(vocab), embed_size), -1, 1))
embed = tf.nn.embedding_lookup(embeddings, inputs)

Linear layer

We use a linear layer with activation=None. We don't need this layer after the training. Think of it as part of the loss function.

In [13]:
logits = tf.layers.dense(embed, len(vocab), activation=None,
    kernel_initializer=tf.random_normal_initializer())

Loss & Optimization

There is a more optimized, noise-contrastive loss function for traning word embeddings: tf.nn.nce_loss. I use tf.nn.softmax_cross_entropy_with_logits for simplicity. For more information about the nce_loss look at the TensorFlow word2vec tutorial.

In [14]:
labels = tf.one_hot(targets, len(vocab))
loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)
loss = tf.reduce_mean(loss)

train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)

Start Session

In [15]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

Training Loop

In [16]:
from sklearn.metrics.pairwise import cosine_similarity

def get_batches(x, y, batch_size, n=None):
    if n:
        # cheap way to add some randomization
        rand_start = np.random.randint(0, len(x) - batch_size * n)
        x = x[rand_start:]
        y = y[rand_start:]

    for start in range(len(x))[::batch_size][:n]:
        end = start + batch_size
        yield x[start:end], y[start:end]

step = 0
while step < steps:
    start = time.time()
    
    # shuffle train data once in while
    if step % 100000 == 0:
        data_train = data_train.sample(frac=1.)
    
    # train part
    train_loss = []
    for x, y in get_batches(
        data_train['x'].values, data_train['x'].values, batch_size, n=10000):
        step += 1
        _, batch_loss = sess.run([train_op, loss], {inputs: x, targets: y})
        train_loss.append(batch_loss)

    # validation prat (one batch of "validation_size")
    feed_dict = {inputs: data_valid['x'].values, targets: data_valid['x'].values}
    valid_loss, x_vectors = sess.run([loss, embed], feed_dict)
    y_vectors = sess.run(embed, {inputs: data_valid['x'].values})

    # outputs
    print('Step:', step, 'TLoss:', np.mean(train_loss), 'VLoss:', np.mean(valid_loss),
          'Similarity: %.3f' % cosine_similarity(x_vectors, y_vectors).mean(),
          'Seconds %.1f' % (time.time() - start))
Step: 10000 TLoss: 2.0164194 VLoss: 0.5388802 Similarity: 0.023 Seconds 66.5
Step: 20000 TLoss: 0.12752666 VLoss: 0.11306174 Similarity: 0.023 Seconds 65.2
Step: 30000 TLoss: 0.028313937 VLoss: 0.11745877 Similarity: 0.023 Seconds 65.5
Step: 40000 TLoss: 0.010771247 VLoss: 0.012606331 Similarity: 0.024 Seconds 65.8
Step: 50000 TLoss: 0.0013680928 VLoss: 0.012557062 Similarity: 0.024 Seconds 65.9
Step: 60000 TLoss: 0.0041248337 VLoss: 7.6293943e-10 Similarity: 0.025 Seconds 65.8
Step: 70000 TLoss: 0.0014730034 VLoss: 1.0251998e-09 Similarity: 0.026 Seconds 65.9
Step: 80000 TLoss: 0.00054400414 VLoss: 6.67572e-10 Similarity: 0.027 Seconds 65.9
Step: 90000 TLoss: 0.0005858248 VLoss: 1.3113021e-09 Similarity: 0.027 Seconds 65.9
Step: 100000 TLoss: 0.0009175814 VLoss: 8.1062307e-10 Similarity: 0.028 Seconds 65.9
Step: 110000 TLoss: 0.00022360244 VLoss: 9.059905e-10 Similarity: 0.029 Seconds 66.9
Step: 120000 TLoss: 6.2553576e-05 VLoss: 6.4373007e-10 Similarity: 0.030 Seconds 65.8
Step: 130000 TLoss: 0.00019555185 VLoss: 2.4318687e-09 Similarity: 0.030 Seconds 65.8
Step: 140000 TLoss: 5.0704068e-05 VLoss: 9.775161e-10 Similarity: 0.031 Seconds 65.8
Step: 150000 TLoss: 0.00014003972 VLoss: 1.0251997e-09 Similarity: 0.032 Seconds 65.9
Step: 160000 TLoss: 7.326659e-05 VLoss: 1.2636184e-09 Similarity: 0.033 Seconds 65.9
Step: 170000 TLoss: 0.00014974378 VLoss: 9.775161e-10 Similarity: 0.033 Seconds 65.8
Step: 180000 TLoss: 1.296028e-09 VLoss: 1.7404554e-09 Similarity: 0.034 Seconds 65.8
Step: 190000 TLoss: 9.8594224e-05 VLoss: 8.583068e-10 Similarity: 0.035 Seconds 65.8
Step: 200000 TLoss: 5.5742334e-05 VLoss: 1.3113021e-09 Similarity: 0.036 Seconds 65.8
Step: 210000 TLoss: 1.2936069e-09 VLoss: 9.536743e-10 Similarity: 0.037 Seconds 66.9
Step: 220000 TLoss: 1.3055278e-09 VLoss: 1.4066694e-09 Similarity: 0.037 Seconds 65.8
Step: 230000 TLoss: 0.0001788908 VLoss: 1.0490416e-09 Similarity: 0.038 Seconds 65.8
Step: 240000 TLoss: 5.5293618e-05 VLoss: 2.4795528e-09 Similarity: 0.039 Seconds 65.7
Step: 250000 TLoss: 1.3612207e-09 VLoss: 1.5735624e-09 Similarity: 0.040 Seconds 65.8
Step: 260000 TLoss: 1.3375653e-09 VLoss: 1.0728836e-09 Similarity: 0.041 Seconds 65.8
Step: 270000 TLoss: 1.3418494e-09 VLoss: 1.5258788e-09 Similarity: 0.042 Seconds 65.8
Step: 280000 TLoss: 1.3548879e-09 VLoss: 1.001358e-09 Similarity: 0.043 Seconds 65.7
Step: 290000 TLoss: 1.3586132e-09 VLoss: 1.2159346e-09 Similarity: 0.045 Seconds 65.8
Step: 300000 TLoss: 0.0001109721 VLoss: 9.298324e-10 Similarity: 0.046 Seconds 65.8
Step: 310000 TLoss: 1.3770534e-09 VLoss: 7.867812e-10 Similarity: 0.047 Seconds 66.9
Step: 320000 TLoss: 1.3697891e-09 VLoss: 8.1062307e-10 Similarity: 0.048 Seconds 65.7
Step: 330000 TLoss: 1.3677401e-09 VLoss: 8.8214863e-10 Similarity: 0.050 Seconds 65.8
Step: 340000 TLoss: 1.3770534e-09 VLoss: 1.0490416e-09 Similarity: 0.051 Seconds 65.7
Step: 350000 TLoss: 1.3750044e-09 VLoss: 9.775161e-10 Similarity: 0.052 Seconds 65.8
Step: 360000 TLoss: 1.3889743e-09 VLoss: 1.2159347e-09 Similarity: 0.054 Seconds 65.8
Step: 370000 TLoss: 1.3785435e-09 VLoss: 9.298324e-10 Similarity: 0.055 Seconds 65.8
Step: 380000 TLoss: 1.3930721e-09 VLoss: 9.536743e-10 Similarity: 0.057 Seconds 65.8
Step: 390000 TLoss: 1.3809649e-09 VLoss: 1.2397765e-09 Similarity: 0.059 Seconds 65.8
Step: 400000 TLoss: 1.3953073e-09 VLoss: 1.1205672e-09 Similarity: 0.060 Seconds 65.8
Step: 410000 TLoss: 1.3913958e-09 VLoss: 1.3589857e-09 Similarity: 0.062 Seconds 67.0
Step: 420000 TLoss: 1.4010815e-09 VLoss: 1.4066694e-09 Similarity: 0.064 Seconds 65.7
Step: 430000 TLoss: 1.400709e-09 VLoss: 1.2874601e-09 Similarity: 0.066 Seconds 65.8
Step: 440000 TLoss: 1.4089047e-09 VLoss: 1.5258788e-09 Similarity: 0.068 Seconds 65.8
Step: 450000 TLoss: 1.407787e-09 VLoss: 1.9550321e-09 Similarity: 0.070 Seconds 65.8
Step: 460000 TLoss: 1.4074145e-09 VLoss: 1.2159347e-09 Similarity: 0.072 Seconds 65.8
Step: 470000 TLoss: 1.4100222e-09 VLoss: 1.2397764e-09 Similarity: 0.074 Seconds 65.9
Step: 480000 TLoss: 1.4215706e-09 VLoss: 9.536743e-10 Similarity: 0.076 Seconds 65.8
Step: 490000 TLoss: 1.4159827e-09 VLoss: 9.298324e-10 Similarity: 0.079 Seconds 65.8
Step: 500000 TLoss: 1.416169e-09 VLoss: 9.298324e-10 Similarity: 0.081 Seconds 65.8
Step: 510000 TLoss: 1.4213843e-09 VLoss: 9.298324e-10 Similarity: 0.083 Seconds 66.9
Step: 520000 TLoss: 1.4251096e-09 VLoss: 1.0967254e-09 Similarity: 0.086 Seconds 65.8
Step: 530000 TLoss: 1.4256684e-09 VLoss: 1.0251998e-09 Similarity: 0.089 Seconds 65.8
Step: 540000 TLoss: 1.4100222e-09 VLoss: 9.536743e-10 Similarity: 0.092 Seconds 65.8
Step: 550000 TLoss: 6.02954e-05 VLoss: 1.3828276e-09 Similarity: 0.095 Seconds 65.8
Step: 560000 TLoss: 1.4210118e-09 VLoss: 9.298324e-10 Similarity: 0.097 Seconds 65.8
Step: 570000 TLoss: 1.4185904e-09 VLoss: 1.0967254e-09 Similarity: 0.100 Seconds 65.8
Step: 580000 TLoss: 1.4187767e-09 VLoss: 1.7404554e-09 Similarity: 0.104 Seconds 65.9
Step: 590000 TLoss: 1.4262272e-09 VLoss: 1.1444091e-09 Similarity: 0.107 Seconds 65.8
Step: 600000 TLoss: 1.4228745e-09 VLoss: 1.0728836e-09 Similarity: 0.110 Seconds 65.7
Step: 610000 TLoss: 1.4236196e-09 VLoss: 1.0967254e-09 Similarity: 0.113 Seconds 66.9
Step: 620000 TLoss: 1.4243645e-09 VLoss: 1.1920928e-09 Similarity: 0.117 Seconds 65.8
Step: 630000 TLoss: 1.4230607e-09 VLoss: 1.5258789e-09 Similarity: 0.120 Seconds 65.9
Step: 640000 TLoss: 1.4180316e-09 VLoss: 1.0490416e-09 Similarity: 0.124 Seconds 65.8
Step: 650000 TLoss: 1.4292074e-09 VLoss: 1.2159346e-09 Similarity: 0.127 Seconds 65.7
Step: 660000 TLoss: 1.4111398e-09 VLoss: 1.2874602e-09 Similarity: 0.131 Seconds 65.9
Step: 670000 TLoss: 1.4150513e-09 VLoss: 1.0251998e-09 Similarity: 0.135 Seconds 65.8
Step: 680000 TLoss: 1.4187767e-09 VLoss: 1.2397764e-09 Similarity: 0.138 Seconds 65.8
Step: 690000 TLoss: 1.4150513e-09 VLoss: 1.5735624e-09 Similarity: 0.142 Seconds 65.8
Step: 700000 TLoss: 1.4124436e-09 VLoss: 1.2397764e-09 Similarity: 0.146 Seconds 65.8
Step: 710000 TLoss: 1.4159827e-09 VLoss: 6.437301e-10 Similarity: 0.189 Seconds 66.9
Step: 720000 TLoss: 1.5705822e-09 VLoss: 1.3589857e-09 Similarity: 0.246 Seconds 65.7
Step: 730000 TLoss: 1.493655e-09 VLoss: 1.9073485e-09 Similarity: 0.250 Seconds 65.8
Step: 740000 TLoss: 1.474656e-09 VLoss: 1.3351439e-09 Similarity: 0.254 Seconds 65.8
Step: 750000 TLoss: 1.4690681e-09 VLoss: 1.9073485e-09 Similarity: 0.258 Seconds 65.7
Step: 760000 TLoss: 1.4789401e-09 VLoss: 1.1205672e-09 Similarity: 0.262 Seconds 65.7
Step: 770000 TLoss: 1.462735e-09 VLoss: 1.4066696e-09 Similarity: 0.266 Seconds 65.7
Step: 780000 TLoss: 0.00051325257 VLoss: 4.7683715e-11 Similarity: 0.209 Seconds 65.8
Step: 790000 TLoss: 2.8088684e-10 VLoss: 1.9073486e-10 Similarity: 0.212 Seconds 65.7
Step: 800000 TLoss: 3.341585e-10 VLoss: 2.3841856e-10 Similarity: 0.214 Seconds 65.8
Step: 810000 TLoss: 3.8016584e-10 VLoss: 6.67572e-10 Similarity: 0.216 Seconds 66.8
Step: 820000 TLoss: 4.0736045e-10 VLoss: 5.245208e-10 Similarity: 0.217 Seconds 65.7
Step: 830000 TLoss: 4.308298e-10 VLoss: 3.814697e-10 Similarity: 0.219 Seconds 65.7
Step: 840000 TLoss: 4.3716278e-10 VLoss: 3.5762784e-10 Similarity: 0.220 Seconds 65.7
Step: 850000 TLoss: 4.4349577e-10 VLoss: 7.390975e-10 Similarity: 0.221 Seconds 65.7
Step: 860000 TLoss: 4.5541665e-10 VLoss: 3.8146972e-10 Similarity: 0.223 Seconds 65.8
Step: 870000 TLoss: 4.706904e-10 VLoss: 4.291534e-10 Similarity: 0.224 Seconds 65.7
Step: 880000 TLoss: 4.716217e-10 VLoss: 4.768371e-10 Similarity: 0.225 Seconds 65.7
Step: 890000 TLoss: 4.7758214e-10 VLoss: 7.3909756e-10 Similarity: 0.227 Seconds 65.8
Step: 900000 TLoss: 4.872679e-10 VLoss: 5.245208e-10 Similarity: 0.228 Seconds 65.7
Step: 910000 TLoss: 4.8801296e-10 VLoss: 4.5299525e-10 Similarity: 0.229 Seconds 66.8
Step: 920000 TLoss: 5.003064e-10 VLoss: 7.629394e-10 Similarity: 0.230 Seconds 65.7
Step: 930000 TLoss: 5.122273e-10 VLoss: 5.00679e-10 Similarity: 0.232 Seconds 65.7
Step: 940000 TLoss: 5.0999216e-10 VLoss: 3.33786e-10 Similarity: 0.233 Seconds 65.7
Step: 950000 TLoss: 5.222856e-10 VLoss: 2.8610228e-10 Similarity: 0.234 Seconds 65.8
Step: 960000 TLoss: 5.30295e-10 VLoss: 6.1988825e-10 Similarity: 0.235 Seconds 65.8
Step: 970000 TLoss: 5.427747e-10 VLoss: 6.1988825e-10 Similarity: 0.236 Seconds 65.7
Step: 980000 TLoss: 5.3048127e-10 VLoss: 5.245208e-10 Similarity: 0.237 Seconds 65.8
Step: 990000 TLoss: 5.479901e-10 VLoss: 5.483627e-10 Similarity: 0.238 Seconds 65.7
Step: 1000000 TLoss: 5.507841e-10 VLoss: 4.5299525e-10 Similarity: 0.239 Seconds 65.7

We have trained embeddings!

In [18]:
vectors = sess.run(embeddings)
vectors = pd.DataFrame(vectors, index=vocab.index)

Demonstrate similarity

In [20]:
from sklearn.metrics.pairwise import cosine_similarity

print('Similarity:')
print('   computer to mouse =', cosine_similarity(vectors.loc[['computer']], vectors.loc[['mouse']])[0][0])
print('   cat to mouse =', cosine_similarity(vectors.loc[['cat']], vectors.loc[['mouse']])[0][0])
print('   dog to mouse =', cosine_similarity(vectors.loc[['dog']], vectors.loc[['mouse']])[0][0])
Similarity:
   computer to mouse = 0.05870525
   cat to mouse = 0.052366085
   dog to mouse = -0.009641118

Comments

Comments powered by Disqus