222 changes: 222 additions & 0 deletions BETO_getting_started.ipynb
@@ -0,0 +1,222 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "BETO_getting_started.ipynb",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/mrm8488/beto/blob/master/BETO_getting_started.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "p2LNuj2a59BL",
"colab_type": "text"
},
"source": [
"#Getting started with BETO\n",
"> Colab by [Manuel Romero](https://twitter.com/mrm8488)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SkesDoqe6Dku",
"colab_type": "text"
},
"source": [
"# BETO: Spanish BERT\n",
"\n",
"BETO is a [BERT model](https://github.com/google-research/bert) trained on a [big Spanish corpus](https://github.com/josecannete/spanish-corpora). BETO is of size similar to a BERT-Base and was trained with the Whole Word Masking technique. Below you find Tensorflow and Pytorch checkpoints for the uncased and cased versions, as well as some results for Spanish benchmarks comparing BETO with [Multilingual BERT](https://github.com/google-research/bert/blob/master/multilingual.md) as well as other (not BERT-based) models."
]
},
{
"cell_type": "code",
"metadata": {
"id": "1q1CU0O06ELS",
"colab_type": "code",
"colab": {}
},
"source": [
"!pip install transformers > /dev/null"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "LWy074dJ6HW3",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Avka9_xF6Ja2",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 64
},
"outputId": "d25944dd-cda3-453c-95c3-4e37c4405a5e"
},
"source": [
"from transformers import BertTokenizer, BertForMaskedLM"
],
"execution_count": 3,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<p style=\"color: red;\">\n",
"The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x.<br>\n",
"We recommend you <a href=\"https://www.tensorflow.org/guide/migrate\" target=\"_blank\">upgrade</a> now \n",
"or ensure your notebook will continue to use TensorFlow 1.x via the <code>%tensorflow_version 1.x</code> magic:\n",
"<a href=\"https://colab.research.google.com/notebooks/tensorflow_version.ipynb\" target=\"_blank\">more info</a>.</p>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "bwS9Zu8b6Nd8",
"colab_type": "code",
"colab": {}
},
"source": [
"tokenizer = BertTokenizer.from_pretrained(\"dccuchile/cased\")\n",
"\n",
"model = BertForMaskedLM.from_pretrained(\"dccuchile/cased\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "w3nzzyjG6OUA",
"colab_type": "code",
"colab": {}
},
"source": [
"model.eval()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "eTmaZgv06STH",
"colab_type": "code",
"colab": {}
},
"source": [
"text = \"[CLS] Para solucionar los [MASK] de Chile, el presidente debe [MASK] de inmediato. [SEP]\""
],
"execution_count": 0,
"outputs": []
},
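{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell hardcodes the positions of the two `[MASK]` tokens in the tokenized sentence. As a quick sanity check, the sketch below (assuming the tokenizer keeps `[MASK]` as a single token, as BERT-style tokenizers do) locates those positions programmatically, so the hardcoded values can be verified or adjusted if any word gets split into subwords."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Optional sanity check: locate the [MASK] positions instead of hardcoding them\n",
"tokens = tokenizer.tokenize(text)\n",
"mask_positions = [i for i, tok in enumerate(tokens) if tok == tokenizer.mask_token]\n",
"print(mask_positions)  # expected: [4, 11] if no surrounding word is split into subwords"
],
"execution_count": 0,
"outputs": []
},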
{
"cell_type": "code",
"metadata": {
"id": "ZqsADzQI6UcH",
"colab_type": "code",
"colab": {}
},
"source": [
"masked_indxs = (4,11)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "rvWXuR6b6Vft",
"colab_type": "code",
"colab": {}
},
"source": [
"indexed_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))\n",
"tokens_tensor = torch.tensor([indexed_tokens])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "M9eUOzvE6XWq",
"colab_type": "code",
"colab": {}
},
"source": [
"predictions = model(tokens_tensor)[0]"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "tUvd9IxG6aoz",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"outputId": "e608ecdc-93c3-4b2d-de8b-f43be99a0cce"
},
"source": [
"for i, midx in enumerate(masked_indxs):\n",
" idxs = torch.argsort(predictions[0,midx], descending=True)\n",
" predicted_token = tokenizer.convert_ids_to_tokens(idxs[:5])\n",
" print('MASKS', i, ':', predicted_token)"
],
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"text": [
"MASKS 0 : ['enz', 'Civiles', '##fasis', 'musulmanes', 'Inicia']\n",
"MASKS 1 : ['##ANCI', '##cales', '##uez', '##idi', 'enz']\n"
],
"name": "stdout"
}
]
}
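,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Alternative: the `fill-mask` pipeline\n",
"\n",
"The cells above drive the tokenizer and the masked-LM head by hand. As a minimal sketch, the same prediction can be done with the higher-level `fill-mask` pipeline. This assumes a `transformers` release recent enough to ship that pipeline and assumes the cased BETO checkpoint is published on the Hugging Face Hub as `dccuchile/bert-base-spanish-wwm-cased`; to keep the output simple, the example masks a single word."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from transformers import pipeline\n",
"\n",
"# Assumed Hub id for the cased BETO checkpoint; adjust if you load from a local path\n",
"fill_mask = pipeline(\"fill-mask\",\n",
"                     model=\"dccuchile/bert-base-spanish-wwm-cased\",\n",
"                     tokenizer=\"dccuchile/bert-base-spanish-wwm-cased\")\n",
"\n",
"# Print the top candidates for the single [MASK], highest score first\n",
"for pred in fill_mask(\"Para solucionar los [MASK] de Chile, el presidente debe actuar de inmediato.\"):\n",
"    print(round(pred['score'], 4), pred['sequence'])"
],
"execution_count": 0,
"outputs": []
}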
]
}