{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "hX4n9TsbGw-f"
},
"source": [
"##### Copyright 2018 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "0nbI5DtDGw-i"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9TnJztDZGw-n"
},
"source": [
"# Text classification with an RNN"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AfN3bMR5Gw-o"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/text/tutorials/text_classification_rnn\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/text/blob/master/docs/tutorials/text_classification_rnn.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/text/blob/master/docs/tutorials/text_classification_rnn.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/text/docs/tutorials/text_classification_rnn.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lUWearf0Gw-p"
},
"source": [
"This text classification tutorial trains a [recurrent neural network](https://developers.google.com/machine-learning/glossary/#recurrent_neural_network) on the [IMDB large movie review dataset](https://ai.stanford.edu/~amaas/data/sentiment/) for sentiment analysis."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_2VQo4bajwUU"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z682XYsrjkY9"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"import tensorflow_datasets as tfds\n",
"import tensorflow as tf\n",
"\n",
"tfds.disable_progress_bar()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1rXHa-w9JZhb"
},
"source": [
"Import `matplotlib` and create a helper function to plot graphs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Mp1Z7P9pYRSK"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def plot_graphs(history, metric):\n",
" plt.plot(history.history[metric])\n",
" plt.plot(history.history['val_'+metric], '')\n",
" plt.xlabel(\"Epochs\")\n",
" plt.ylabel(metric)\n",
" plt.legend([metric, 'val_'+metric])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pRmMubr0jrE2"
},
"source": [
"## Setup input pipeline\n",
"\n",
"\n",
"The IMDB large movie review dataset is a *binary classification* dataset—all the reviews have either a *positive* or *negative* sentiment.\n",
"\n",
"Download the dataset using [TFDS](https://www.tensorflow.org/datasets). See the [loading text tutorial](https://www.tensorflow.org/tutorials/load_data/text) for details on how to load this sort of data manually.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SHRwRoP2nVHX"
},
"outputs": [],
"source": [
"dataset, info = tfds.load('imdb_reviews', with_info=True,\n",
" as_supervised=True)\n",
"train_dataset, test_dataset = dataset['train'], dataset['test']\n",
"\n",
"train_dataset.element_spec"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nWA4c2ir7g6p"
},
"source": [
"Initially this returns a dataset of (text, label pairs):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vd4_BGKyurao"
},
"outputs": [],
"source": [
"for example, label in train_dataset.take(1):\n",
" print('text: ', example.numpy())\n",
" print('label: ', label.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z2qVJzcEluH_"
},
"source": [
"Next shuffle the data for training and create batches of these `(text, label)` pairs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dDsCaZCDYZgm"
},
"outputs": [],
"source": [
"BUFFER_SIZE = 10000\n",
"BATCH_SIZE = 64"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VznrltNOnUc5"
},
"outputs": [],
"source": [
"train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)\n",
"test_dataset = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jqkvdcFv41wC"
},
"outputs": [],
"source": [
"for example, label in train_dataset.take(1):\n",
" print('texts: ', example.numpy()[:3])\n",
" print()\n",
" print('labels: ', label.numpy()[:3])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s5eWCo88voPY"
},
"source": [
"## Create the text encoder"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TFevcItw15P_"
},
"source": [
"The raw text loaded by `tfds` needs to be processed before it can be used in a model. The simplest way to process text for training is using the `TextVectorization` layer. This layer has many capabilities, but this tutorial sticks to the default behavior.\n",
"\n",
"Create the layer, and pass the dataset's text to the layer's `.adapt` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uC25Lu1Yvuqy"
},
"outputs": [],
"source": [
"VOCAB_SIZE = 1000\n",
"encoder = tf.keras.layers.TextVectorization(\n",
" max_tokens=VOCAB_SIZE)\n",
"encoder.adapt(train_dataset.map(lambda text, label: text))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IuQzVBbe3Ldu"
},
"source": [
"The `.adapt` method sets the layer's vocabulary. Here are the first 20 tokens. After the padding and unknown tokens they're sorted by frequency: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tBoyjjWg0Ac9"
},
"outputs": [],
"source": [
"vocab = np.array(encoder.get_vocabulary())\n",
"vocab[:20]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mjId5pua3jHQ"
},
"source": [
"Once the vocabulary is set, the layer can encode text into indices. The tensors of indices are 0-padded to the longest sequence in the batch (unless you set a fixed `output_sequence_length`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RGc7C9WiwRWs"
},
"outputs": [],
"source": [
"encoded_example = encoder(example)[:3].numpy()\n",
"encoded_example"
]
},
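    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you instead need fixed-size outputs (for example, for a model that requires static shapes), you can set `output_sequence_length`. A minimal sketch, where the `fixed_encoder` name and the length of 128 are illustrative choices:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# A hypothetical second encoder that pads or truncates every example\n",
        "# to exactly 128 tokens (the name and length are illustrative).\n",
        "fixed_encoder = tf.keras.layers.TextVectorization(\n",
        "    max_tokens=VOCAB_SIZE,\n",
        "    output_sequence_length=128)\n",
        "fixed_encoder.adapt(train_dataset.map(lambda text, label: text))\n",
        "print(fixed_encoder(example).shape)  # (batch_size, 128), whatever the review lengths"
      ]
    },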
{
"cell_type": "markdown",
"metadata": {
"id": "F5cjz0bS39IN"
},
"source": [
"With the default settings, the process is not completely reversible. There are three main reasons for that:\n",
"\n",
"1. The default value for `preprocessing.TextVectorization`'s `standardize` argument is `\"lower_and_strip_punctuation\"`.\n",
"2. The limited vocabulary size and lack of character-based fallback results in some unknown tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "N_tD0QY5wXaK"
},
"outputs": [],
"source": [
"for n in range(3):\n",
" print(\"Original: \", example[n].numpy())\n",
" print(\"Round-trip: \", \" \".join(vocab[encoded_example[n]]))\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bjUqGVBxGw-t"
},
"source": [
"## Create the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W7zsmInBOCPO"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bgs6nnSTGw-t"
},
"source": [
"Above is a diagram of the model. \n",
"\n",
"1. This model can be build as a `tf.keras.Sequential`.\n",
"\n",
"2. The first layer is the `encoder`, which converts the text to a sequence of token indices.\n",
"\n",
"3. After the encoder is an embedding layer. An embedding layer stores one vector per word. When called, it converts the sequences of word indices to sequences of vectors. These vectors are trainable. After training (on enough data), words with similar meanings often have similar vectors.\n",
"\n",
" This index-lookup is much more efficient than the equivalent operation of passing a one-hot encoded vector through a `tf.keras.layers.Dense` layer.\n",
"\n",
"4. A recurrent neural network (RNN) processes sequence input by iterating through the elements. RNNs pass the outputs from one timestep to their input on the next timestep.\n",
"\n",
" The `tf.keras.layers.Bidirectional` wrapper can also be used with an RNN layer. This propagates the input forward and backwards through the RNN layer and then concatenates the final output. \n",
"\n",
" * The main advantage of a bidirectional RNN is that the signal from the beginning of the input doesn't need to be processed all the way through every timestep to affect the output. \n",
"\n",
" * The main disadvantage of a bidirectional RNN is that you can't efficiently stream predictions as words are being added to the end.\n",
"\n",
"5. After the RNN has converted the sequence to a single vector the two `layers.Dense` do some final processing, and convert from this vector representation to a single logit as the classification output. \n"
]
},
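    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To illustrate point 3, here is a small sketch showing that an embedding lookup matches multiplying a one-hot matrix by the same weight matrix (the toy sizes and the `demo_embedding` name are illustrative):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Toy illustration: an embedding lookup equals one-hot @ weight-matrix,\n",
        "# but never materializes the one-hot vectors.\n",
        "demo_embedding = tf.keras.layers.Embedding(input_dim=5, output_dim=3)\n",
        "indices = tf.constant([0, 2, 4])\n",
        "\n",
        "looked_up = demo_embedding(indices)  # direct index lookup (also builds the layer)\n",
        "one_hot = tf.one_hot(indices, depth=5)  # (3, 5) one-hot matrix\n",
        "via_matmul = tf.matmul(one_hot, demo_embedding.embeddings)\n",
        "\n",
        "print(np.allclose(looked_up.numpy(), via_matmul.numpy()))  # True"
      ]
    },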
{
"cell_type": "markdown",
"metadata": {
"id": "V4fodCI7soQi"
},
"source": [
"The code to implement this is below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LwfoBkmRYcP3"
},
"outputs": [],
"source": [
"model = tf.keras.Sequential([\n",
" encoder,\n",
" tf.keras.layers.Embedding(\n",
" input_dim=len(encoder.get_vocabulary()),\n",
" output_dim=64,\n",
" # Use masking to handle the variable sequence lengths\n",
" mask_zero=True),\n",
" tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),\n",
" tf.keras.layers.Dense(64, activation='relu'),\n",
" tf.keras.layers.Dense(1)\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QIGmIGkkouUb"
},
"source": [
"Please note that Keras sequential model is used here since all the layers in the model only have single input and produce single output. In case you want to use stateful RNN layer, you might want to build your model with Keras functional API or model subclassing so that you can retrieve and reuse the RNN layer states. Please check [Keras RNN guide](https://www.tensorflow.org/guide/keras/rnn#rnn_state_reuse) for more details."
]
},
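    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, here is a minimal functional-API sketch of the same architecture (the `functional_model` name is illustrative, and this model is not used below):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# A functional-API sketch of the same architecture (illustrative only;\n",
        "# the sequential model above is the one trained in this tutorial).\n",
        "inputs = tf.keras.Input(shape=(1,), dtype=tf.string)\n",
        "x = encoder(inputs)\n",
        "x = tf.keras.layers.Embedding(\n",
        "    input_dim=len(encoder.get_vocabulary()), output_dim=64, mask_zero=True)(x)\n",
        "x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64))(x)\n",
        "x = tf.keras.layers.Dense(64, activation='relu')(x)\n",
        "outputs = tf.keras.layers.Dense(1)(x)\n",
        "functional_model = tf.keras.Model(inputs, outputs)"
      ]
    },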
{
"cell_type": "markdown",
"metadata": {
"id": "kF-PsCk1LwjY"
},
"source": [
"The embedding layer [uses masking](https://www.tensorflow.org/guide/keras/masking_and_padding) to handle the varying sequence-lengths. All the layers after the `Embedding` support masking:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "87a8-CwfKebw"
},
"outputs": [],
"source": [
"print([layer.supports_masking for layer in model.layers])"
]
},
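    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "You can also inspect the mask itself. A small sketch using the `example` batch from earlier: padded, zero-valued positions come back as `False` and are ignored by the downstream layers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Peek at the boolean mask the Embedding layer produces:\n",
        "# False marks the zero-padded positions.\n",
        "embedding_layer = model.layers[1]\n",
        "mask = embedding_layer.compute_mask(encoder(example))\n",
        "print(mask.shape)\n",
        "print(mask[0].numpy())"
      ]
    },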
{
"cell_type": "markdown",
"metadata": {
"id": "ZlS0iaUIWLpI"
},
"source": [
"To confirm that this works as expected, evaluate a sentence twice. First, alone so there's no padding to mask:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "O41gw3KfWHus"
},
"outputs": [],
"source": [
"# predict on a sample text without padding.\n",
"\n",
"sample_text = ('The movie was cool. The animation and the graphics '\n",
" 'were out of this world. I would recommend this movie.')\n",
"predictions = model.predict(np.array([sample_text]))\n",
"print(predictions[0])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K0VQmGnEWcuz"
},
"source": [
"Now, evaluate it again in a batch with a longer sentence. The result should be identical:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UIgpuTeFNDzq"
},
"outputs": [],
"source": [
"# predict on a sample text with padding\n",
"\n",
"padding = \"the \" * 2000\n",
"predictions = model.predict(np.array([sample_text, padding]))\n",
"print(predictions[0])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sRI776ZcH3Tf"
},
"source": [
"Compile the Keras model to configure the training process:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kj2xei41YZjC"
},
"outputs": [],
"source": [
"model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
" optimizer=tf.keras.optimizers.Adam(1e-4),\n",
" metrics=['accuracy'])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zIwH3nto596k"
},
"source": [
"## Train the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hw86wWS4YgR2"
},
"outputs": [],
"source": [
"history = model.fit(train_dataset, epochs=10,\n",
" validation_data=test_dataset,\n",
" validation_steps=30)"
]
},
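    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The returned `History` object records one value per epoch for each metric; these are the keys that the `plot_graphs` helper reads:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Each key maps to a list with one entry per epoch.\n",
        "print(history.history.keys())  # loss, accuracy, val_loss, val_accuracy"
      ]
    },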
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BaNbXi43YgUT"
},
"outputs": [],
"source": [
"test_loss, test_acc = model.evaluate(test_dataset)\n",
"\n",
"print('Test Loss:', test_loss)\n",
"print('Test Accuracy:', test_acc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OZmwt_mzaQJk"
},
"outputs": [],
"source": [
"plt.figure(figsize=(16, 8))\n",
"plt.subplot(1, 2, 1)\n",
"plot_graphs(history, 'accuracy')\n",
"plt.ylim(None, 1)\n",
"plt.subplot(1, 2, 2)\n",
"plot_graphs(history, 'loss')\n",
"plt.ylim(0, None)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DwSE_386uhxD"
},
"source": [
"Run a prediction on a new sentence:\n",
"\n",
"If the prediction is \u003e= 0.0, it is positive else it is negative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZXgfQSgRW6zU"
},
"outputs": [],
"source": [
"sample_text = ('The movie was cool. The animation and the graphics '\n",
" 'were out of this world. I would recommend this movie.')\n",
"predictions = model.predict(np.array([sample_text]))"
]
},
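    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The model outputs a raw logit. If you want a probability instead, a small sketch: pass the logit through a sigmoid."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Squash the logit into a probability in [0, 1] with a sigmoid.\n",
        "print(tf.sigmoid(predictions[0]).numpy())"
      ]
    },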
{
"cell_type": "markdown",
"metadata": {
"id": "7g1evcaRpTKm"
},
"source": [
"## Stack two or more LSTM layers\n",
"\n",
"Keras recurrent layers have two available modes that are controlled by the `return_sequences` constructor argument:\n",
"\n",
"* If `False` it returns only the last output for each input sequence (a 2D tensor of shape (batch_size, output_features)). This is the default, used in the previous model.\n",
"\n",
"* If `True` the full sequences of successive outputs for each timestep is returned (a 3D tensor of shape `(batch_size, timesteps, output_features)`).\n",
"\n",
"Here is what the flow of information looks like with `return_sequences=True`:\n",
"\n",
""
]
},
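    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick sketch on random data showing how the two output shapes differ (the toy batch, timestep, and feature sizes below are illustrative):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Compare the output shapes of the two modes on random data.\n",
        "x = tf.random.normal([4, 10, 16])  # (batch, timesteps, features)\n",
        "print(tf.keras.layers.LSTM(8)(x).shape)  # (4, 8): last output only\n",
        "print(tf.keras.layers.LSTM(8, return_sequences=True)(x).shape)  # (4, 10, 8): one output per timestep"
      ]
    },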
{
"cell_type": "markdown",
"metadata": {
"id": "wbSClCrG1z8l"
},
"source": [
"The interesting thing about using an `RNN` with `return_sequences=True` is that the output still has 3-axes, like the input, so it can be passed to another RNN layer, like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jo1jjO3vn0jo"
},
"outputs": [],
"source": [
"model = tf.keras.Sequential([\n",
" encoder,\n",
" tf.keras.layers.Embedding(len(encoder.get_vocabulary()), 64, mask_zero=True),\n",
" tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)),\n",
" tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),\n",
" tf.keras.layers.Dense(64, activation='relu'),\n",
" tf.keras.layers.Dropout(0.5),\n",
" tf.keras.layers.Dense(1)\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hEPV5jVGp-is"
},
"outputs": [],
"source": [
"model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
" optimizer=tf.keras.optimizers.Adam(1e-4),\n",
" metrics=['accuracy'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LeSE-YjdqAeN"
},
"outputs": [],
"source": [
"history = model.fit(train_dataset, epochs=10,\n",
" validation_data=test_dataset,\n",
" validation_steps=30)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_LdwilM1qPM3"
},
"outputs": [],
"source": [
"test_loss, test_acc = model.evaluate(test_dataset)\n",
"\n",
"print('Test Loss:', test_loss)\n",
"print('Test Accuracy:', test_acc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ykUKnAoqbycW"
},
"outputs": [],
"source": [
"# predict on a sample text without padding.\n",
"\n",
"sample_text = ('The movie was not good. The animation and the graphics '\n",
" 'were terrible. I would not recommend this movie.')\n",
"predictions = model.predict(np.array([sample_text]))\n",
"print(predictions)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_YYub0EDtwCu"
},
"outputs": [],
"source": [
"plt.figure(figsize=(16, 6))\n",
"plt.subplot(1, 2, 1)\n",
"plot_graphs(history, 'accuracy')\n",
"plt.subplot(1, 2, 2)\n",
"plot_graphs(history, 'loss')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9xvpE3BaGw_V"
},
"source": [
"Check out other existing recurrent layers such as [GRU layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers/GRU).\n",
"\n",
"If you're interested in building custom RNNs, see the [Keras RNN Guide](https://www.tensorflow.org/guide/keras/rnn).\n"
]
}
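,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For example, a drop-in sketch of the first model with a `GRU` in place of the `LSTM` (the `gru_model` name is illustrative, and this model is not trained here):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# A GRU variant of the first model (illustrative sketch, not trained here).\n",
        "gru_model = tf.keras.Sequential([\n",
        "    encoder,\n",
        "    tf.keras.layers.Embedding(len(encoder.get_vocabulary()), 64, mask_zero=True),\n",
        "    tf.keras.layers.Bidirectional(tf.keras.layers.GRU(64)),\n",
        "    tf.keras.layers.Dense(64, activation='relu'),\n",
        "    tf.keras.layers.Dense(1)\n",
        "])"
      ]
    }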
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "text_classification_rnn.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}