Commit aea54e64 authored by Haghighatkhah, P.

Bert embedding extraction, added gitignore

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.zip
# C extensions
*.so
# Folders
bert_model/
uncased_L-12_H-768_A-12/
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject
### VirtualEnv template
# Virtualenv
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
.Python
[Bb]in
[Ii]nclude
[Ll]ib
[Ll]ib64
[Ll]ocal
[Ss]cripts
pyvenv.cfg
.venv
pip-selfcheck.json
### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff:
.idea/workspace.xml
.idea/tasks.xml
.idea/dictionaries
.idea/vcs.xml
.idea/jsLibraryMappings.xml
# Sensitive or high-churn files:
.idea/dataSources.ids
.idea/dataSources.xml
.idea/dataSources.local.xml
.idea/sqlDataSources.xml
.idea/dynamic.xml
.idea/uiDesigner.xml
# Gradle:
.idea/gradle.xml
.idea/libraries
# Mongo Explorer plugin:
.idea/mongoSettings.xml
.idea/
## File-based project format:
*.iws
## Plugin-specific files:
# IntelliJ
/out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
import torch
from transformers import BertTokenizer, BertModel, BertConfig, BertForMaskedLM

import helper


def load_model(model_name):
    # Load the pre-trained tokenizer (vocabulary)
    tokenizer = BertTokenizer.from_pretrained(model_name)
    # Load the pre-trained model (weights)
    model = BertModel.from_pretrained(model_name,
                                      output_hidden_states=True,  # return all hidden states
                                      )
    model.eval()
    return model, tokenizer


def get_tokens_tensors(tokenized_text, tokenizer):
    # Maps tokens to vocabulary ids and wraps them in a batch of size one.
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    return tokens_tensor


def get_segment_tensor(first_sent_len, second_sent_len=0):
    # Segment ids: 0 for tokens of the first sentence, 1 for the second.
    seg_ids = [0] * first_sent_len + [1] * second_sent_len
    return torch.tensor([seg_ids])


def convert_single_sentence_to_bert_input(sentence, tokenizer, sep_tok=' [SEP]', cls_tok='[CLS] '):
    # Takes a single sentence and adds [CLS] and [SEP] at its beginning and end.
    # Tokenizes the marked text and returns the tokens tensor and the segments tensor.
    marked_text = cls_tok + sentence + sep_tok
    tokenized_text = tokenizer.tokenize(marked_text)
    tokens_tensor = get_tokens_tensors(tokenized_text, tokenizer)
    segments_tensor = get_segment_tensor(first_sent_len=tokens_tensor.shape[1])
    print('segments tensor size:', segments_tensor.size())
    return tokens_tensor, segments_tensor


def convert_text_to_bert_input(text, tokenizer, cls_tok='[CLS] '):
    # Takes a piece of text and adds [CLS] at its beginning.
    # Tokenizes the marked text and returns the tokens tensor and the segments tensor.
    marked_text = cls_tok + text
    tokenized_text = tokenizer.tokenize(marked_text)
    tokens_tensor = get_tokens_tensors(tokenized_text, tokenizer)
    segments_tensor = get_segment_tensor(first_sent_len=tokens_tensor.shape[1])
    return tokens_tensor, segments_tensor


def convert_two_sentences_to_bert_input(sentence_one, sentence_two, tokenizer, sep_tok=' [SEP]', cls_tok='[CLS] '):
    # Takes two sentences, adds [CLS] at the beginning and [SEP] between the sentences.
    # Tokenizes the marked text and returns the tokens tensor and the segments tensor.
    marked_text = cls_tok + sentence_one + sep_tok + sentence_two
    tokenized_text = tokenizer.tokenize(marked_text)
    tokens_tensor = get_tokens_tensors(tokenized_text, tokenizer)
    first_sent_len = tokenized_text.index(sep_tok.strip(' ')) + 1
    second_sent_len = len(tokenized_text) - first_sent_len
    segments_tensor = get_segment_tensor(first_sent_len, second_sent_len)
    return tokens_tensor, segments_tensor


def get_hidden_states(tokens_tensor, segments_tensor, model):
    # Runs the model on the tokens and segments tensors.
    # Returns the hidden states as a tuple of tensors, one per layer.
    with torch.no_grad():
        # Pass the segment ids as token_type_ids (keyword) so they are not
        # misread as the attention mask.
        outputs = model(tokens_tensor, token_type_ids=segments_tensor)
        hidden_states = outputs[2]
    return hidden_states


if __name__ == '__main__':
    model_name = 'bert-large-uncased'
    model, tokenizer = load_model(model_name)
    # Special tokens used by BertTokenizer
    mtok = '[MASK]'
    septok = ' [SEP]'
    clstok = '[CLS] '
    text = ("After stealing money from the bank vault, the bank robber was seen "
            "fishing on the Mississippi river bank.")
    tokens_tensor, segments_tensor = convert_single_sentence_to_bert_input(text, tokenizer)
    states = get_hidden_states(tokens_tensor, segments_tensor, model)
    embedding = helper.second_to_last_n_layers_concat(states, 3)
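    # Optional sketch: an embedding extracted this way can be compared against
    # another sentence with cosine similarity. The second sentence below is a
    # made-up illustration, not part of the original experiment.
    text_b = "The robber took the money from the vault of the bank."
    tokens_b, segments_b = convert_single_sentence_to_bert_input(text_b, tokenizer)
    states_b = get_hidden_states(tokens_b, segments_b, model)
    embedding_b = helper.second_to_last_n_layers_concat(states_b, 3)
    similarity = torch.nn.functional.cosine_similarity(embedding, embedding_b, dim=0)
    print('cosine similarity:', similarity.item())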
import torch


def get_token_embeddings(hidden_states):
    # Reshapes the hidden-states list [# layers, # batches, # tokens, # features]
    # into a tensor of shape [# tokens, # layers, # features].
    token_embeddings = torch.stack(hidden_states, dim=0)
    token_embeddings = torch.squeeze(token_embeddings, dim=1)
    token_embeddings = token_embeddings.permute(1, 0, 2)
    return token_embeddings


def second_to_last_n_layers_concat(hidden_states, n=2):
    # Each token embedding is the concatenation of layers -n to -2, i.e. the
    # n - 1 layers ending at the second-to-last layer.
    # The text embedding is the average of its token embeddings.
    token_vecs_cat = []
    token_embeddings = get_token_embeddings(hidden_states)
    for token in token_embeddings:
        cat_vecs = torch.cat(tuple(token[-n:-1]), dim=0)
        token_vecs_cat.append(cat_vecs)
    token_vecs_tensor = torch.stack(token_vecs_cat)
    text_embedding = torch.mean(token_vecs_tensor, dim=0)
    return text_embedding


def second_to_last_n_layers_sum(hidden_states, n=2):
    # Each token embedding is the sum of layers -n to -2, i.e. the n - 1 layers
    # ending at the second-to-last layer.
    # The text embedding is the average of its token embeddings.
    token_vecs_sum = []
    token_embeddings = get_token_embeddings(hidden_states)
    for token in token_embeddings:
        sum_vecs = torch.sum(token[-n:-1], dim=0)
        token_vecs_sum.append(sum_vecs)
    token_vecs_tensor = torch.stack(token_vecs_sum)
    text_embedding = torch.mean(token_vecs_tensor, dim=0)
    return text_embedding
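

if __name__ == '__main__':
    # Minimal shape self-check, a sketch on random data rather than real model
    # output. It assumes hidden states shaped like those of bert-large-uncased
    # with output_hidden_states=True: 25 layers, batch size 1, here 5 tokens,
    # 1024 features per token.
    fake_hidden_states = tuple(torch.randn(1, 5, 1024) for _ in range(25))
    emb_cat = second_to_last_n_layers_concat(fake_hidden_states, n=3)
    emb_sum = second_to_last_n_layers_sum(fake_hidden_states, n=3)
    print(emb_cat.shape)  # torch.Size([2048]): layers -3 and -2 concatenated
    print(emb_sum.shape)  # torch.Size([1024]): layers -3 and -2 summed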