{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Word2Vec with Skip-Gram and TensorFlow\n", "\n", "This is a tutorial and a basic example for getting started with the word2vec model by [Mikolov et al.](https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). It is used for learning vector representations of words, called \"word embeddings\". For more information about embeddings, read my previous post.\n", "\n", "### The word2vec model can be trained with two different word representations:\n", "- Continuous Bag-of-Words (CBOW): predicts target words (e.g. 'mat') from source context words ('the cat sits on the')\n", "- Skip-Gram: predicts source context words from the target words\n", "\n", "### Skip-Gram tends to do better, and this tutorial implements word2vec with skip-grams.\n", "The goal is to train the model's embedding layer so that words with similar meanings end up close to each other in their N-dimensional vector representation. The model has two layers: the embedding layer and a linear layer. Because the last layer is linear, the distance between the embedding vectors of two words is linearly related to the distance in their meanings. In other words, we can do arithmetic with the vectors: [king] - [man] + [woman] ~= [queen]\n", "" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2018-01-20T06:21:33.337119Z", "start_time": "2018-01-20T06:21:32.146930Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "env: CUDA_VISIBLE_DEVICES=0\n" ] } ], "source": [ "%env CUDA_VISIBLE_DEVICES=0\n", "import time\n", "import numpy as np\n", "import pandas as pd\n", "import tensorflow as tf\n", "import sklearn\n", "import nltk" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Dataset\n", "\n", "To train a word2vec model, we need a large text corpus. This example uses the text from the \"20 newsgroups dataset\". The dataset contains 11314 messages from a message board, with corresponding labels for their topics. We just merge all messages together and ignore the labels. In practice, it's better to use a larger, domain-specific corpus. Lowercasing the text is optional but recommended when working with a small corpus." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2018-01-20T06:21:33.687569Z", "start_time": "2018-01-20T06:21:33.338822Z" } }, "outputs": [ { "data": { "text/plain": [ "'umd.edu\\norganization: university of maryland, college park\\nlines: 15\\n\\n i was wondering if anyone out there could enlighten me on this car i saw\\nthe other day. it was a 2-door sports car, looked to be from the late 60s/\\nearly 70s. it was called a bric'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.datasets import fetch_20newsgroups\n", "data = fetch_20newsgroups()\n", "\n", "text = ' '.join(data.data).lower()\n", "text[100:350]" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2018-01-20T01:46:41.573105Z", "start_time": "2018-01-20T01:46:41.487327Z" } }, "source": [ "# Sentence Tokenize\n", "\n", "The skip-grams will work better if they are created from text that is split into sentences. ```nltk.sent_tokenize``` breaks a string into a list of sentences, as the short sketch below illustrates."
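] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A quick, hedged illustration (a sketch, not part of the original pipeline): the cell below runs ```nltk.sent_tokenize``` on a small made-up string so you can see what it returns before it is applied to the whole corpus. It relies on NLTK's pre-trained Punkt models, which may need to be downloaded once with ```nltk.download('punkt')```." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch only: a made-up string, not taken from the corpus pipeline.\n", "# nltk.download('punkt')  # uncomment if the Punkt sentence models are missing\n", "sample = 'it was a 2-door sports car. it looked to be from the late 60s. it was called a bricklin.'\n", "nltk.sent_tokenize(sample)"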
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2018-01-20T06:21:42.273360Z", "start_time": "2018-01-20T06:21:33.689146Z" } }, "outputs": [ { "data": { "text/plain": [ "173268" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sentences_text = nltk.sent_tokenize(text)\n", "len(sentences_text)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Word Tokenize\n", "\n", "Next, break all sentences into tokens (words) with ```nltk.word_tokenize```." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2018-01-20T06:22:19.027607Z", "start_time": "2018-01-20T06:21:42.274908Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['please', 'send', 'a', 'brief', 'message', 'detailing', 'your', 'experiences', 'with', 'the', 'procedure', '.']\n" ] } ], "source": [ "sentences = [nltk.word_tokenize(s) for s in sentences_text]\n", "print(sentences[10])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Vocabulary (unique words)\n", "\n", "In this example, we filter out words that appear fewer than 5 times in the text, as well as stop words and punctuation." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2018-01-20T06:22:20.011301Z", "start_time": "2018-01-20T06:22:19.029119Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The vocabulary has: 34016 words\n" ] } ], "source": [ "from collections import Counter\n", "from string import punctuation\n", "from nltk.corpus import stopwords\n", "\n", "# frequency threshold and filter sets\n", "min_count = 5\n", "puncs = set(punctuation)\n", "stops = set(stopwords.words('english'))\n", "\n", "# flatten the tokenized sentences into one list of words\n", "flat_words = []\n", "for sentence in sentences:\n", "    flat_words += sentence\n", "\n", "counts = Counter(flat_words)\n", "counts = pd.DataFrame(counts.most_common())\n", "counts.columns = ['word', 'count']\n", "\n", "# drop rare words, punctuation and stop words\n", "counts = counts[counts['count'] >= min_count]\n", "counts = counts[~counts['word'].isin(puncs)]\n", "counts = counts[~counts['word'].isin(stops)]\n", "\n", "# map each remaining word to an integer index\n", "vocab = pd.Series(range(len(counts)), index=counts['word']).sort_index()\n", "\n", "print('The vocabulary has:', len(vocab), 'words')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Filter tokens not in vocabulary\n", "\n", "Some words were excluded from the vocabulary because they are either too rare or too common to be useful. We have to remove them from our sentences."
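] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A hedged sketch (not part of the original pipeline): the toy vocabulary and sentence below are invented just to show the filtering idea; the next cell applies the same filter to the real ```vocab``` and corpus." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch only: toy data made up for illustration.\n", "# Keep a token only if it is in the vocabulary; a set gives a cheap membership test.\n", "toy_vocab = {'sports', 'car', 'enlighten'}\n", "toy_sentence = ['could', 'anyone', 'enlighten', 'me', 'on', 'this', 'sports', 'car']\n", "[word for word in toy_sentence if word in toy_vocab]"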
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2018-01-20T06:22:23.291119Z", "start_time": "2018-01-20T06:22:20.013049Z" } }, "outputs": [], "source": [ "filtered_sentences = []\n", "\n", "# keep only in-vocabulary tokens and drop sentences that end up empty\n", "for sentence in sentences:\n", "    sentence = [word for word in sentence if word in vocab.index]\n", "    if len(sentence):\n", "        filtered_sentences.append(sentence)\n", "sentences = filtered_sentences" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Transform the words to integer indexes\n", "\n", "Each word is replaced by its integer index from ```vocab```, since the model works with indexes rather than strings." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2018-01-20T06:23:25.362175Z", "start_time": "2018-01-20T06:22:23.292810Z" } }, "outputs": [], "source": [ "for i, sentence in enumerate(sentences):\n", "    sentences[i] = [vocab.loc[word] for word in sentence]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Create Skip-Gram dataset\n", "\n", "Each word in a sentence becomes a target (x) and is paired with the words that appear around it in the same sentence as context (y). The head of the resulting DataFrame of (x, y) pairs is shown below." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2018-01-20T06:23:37.625844Z", "start_time": "2018-01-20T06:23:25.363656Z" }, "scrolled": false }, "outputs": [ { "data": { "text/html": [
"<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>x</th>\n", "      <th>y</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>5816</td>\n", "      <td>4</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>5816</td>\n", "      <td>122</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>5816</td>\n", "      <td>6</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>5816</td>\n", "      <td>159</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>4</td>\n", "      <td>122</td>\n", "    </tr>\n", "