{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "nibpbUnTsxTd"
},
"source": [
"##### Copyright 2018 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "form",
"execution": {
"iopub.execute_input": "2024-08-27T01:20:39.001977Z",
"iopub.status.busy": "2024-08-27T01:20:39.001732Z",
"iopub.status.idle": "2024-08-27T01:20:39.005625Z",
"shell.execute_reply": "2024-08-27T01:20:39.004995Z"
},
"id": "tXAbWHtqs1Y2"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HTgMAvQq-PU_"
},
"source": [
"# Ragged tensors\n",
"\n",
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5DP8XNP-6zlu"
},
"source": [
"**API Documentation:** [`tf.RaggedTensor`](https://www.tensorflow.org/api_docs/python/tf/RaggedTensor) [`tf.ragged`](https://www.tensorflow.org/api_docs/python/tf/ragged)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cDIUjj07-rQg"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:39.009129Z",
"iopub.status.busy": "2024-08-27T01:20:39.008881Z",
"iopub.status.idle": "2024-08-27T01:20:42.614222Z",
"shell.execute_reply": "2024-08-27T01:20:42.613465Z"
},
"id": "KKvdSorS-pDD"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: tensorflow in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (2.17.0)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: absl-py>=1.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (2.1.0)\r\n",
"Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (1.6.3)\r\n",
"Requirement already satisfied: flatbuffers>=24.3.25 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (24.3.25)\r\n",
"Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (0.6.0)\r\n",
"Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (0.2.0)\r\n",
"Requirement already satisfied: h5py>=3.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (3.11.0)\r\n",
"Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (18.1.1)\r\n",
"Requirement already satisfied: ml-dtypes<0.5.0,>=0.3.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (0.4.0)\r\n",
"Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (3.3.0)\r\n",
"Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (24.1)\r\n",
"Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (3.20.3)\r\n",
"Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (2.32.3)\r\n",
"Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (73.0.1)\r\n",
"Requirement already satisfied: six>=1.12.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (1.16.0)\r\n",
"Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (2.4.0)\r\n",
"Requirement already satisfied: typing-extensions>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (4.12.2)\r\n",
"Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (1.16.0)\r\n",
"Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (1.66.0)\r\n",
"Requirement already satisfied: tensorboard<2.18,>=2.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (2.17.1)\r\n",
"Requirement already satisfied: keras>=3.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (3.5.0)\r\n",
"Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (0.37.1)\r\n",
"Requirement already satisfied: numpy<2.0.0,>=1.23.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow) (1.26.4)\r\n",
"Requirement already satisfied: wheel<1.0,>=0.23.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from astunparse>=1.6.0->tensorflow) (0.43.0)\r\n",
"Requirement already satisfied: rich in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.2.0->tensorflow) (13.8.0)\r\n",
"Requirement already satisfied: namex in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.2.0->tensorflow) (0.0.8)\r\n",
"Requirement already satisfied: optree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.2.0->tensorflow) (0.12.1)\r\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow) (3.3.2)\r\n",
"Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow) (3.8)\r\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow) (2.2.2)\r\n",
"Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow) (2024.7.4)\r\n",
"Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.18,>=2.17->tensorflow) (3.7)\r\n",
"Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.18,>=2.17->tensorflow) (0.7.2)\r\n",
"Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.18,>=2.17->tensorflow) (3.0.4)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tensorboard<2.18,>=2.17->tensorflow) (8.4.0)\r\n",
"Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.18,>=2.17->tensorflow) (2.1.5)\r\n",
"Requirement already satisfied: markdown-it-py>=2.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.2.0->tensorflow) (3.0.0)\r\n",
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.2.0->tensorflow) (2.18.0)\r\n",
"Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<2.18,>=2.17->tensorflow) (3.20.1)\r\n",
"Requirement already satisfied: mdurl~=0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown-it-py>=2.2.0->rich->keras>=3.2.0->tensorflow) (0.1.2)\r\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-08-27 01:20:40.536630: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-08-27 01:20:40.557820: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-08-27 01:20:40.563984: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
]
}
],
"source": [
"!pip install --pre -U tensorflow\n",
"import math\n",
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pxi0m_yf-te5"
},
"source": [
"## Overview\n",
"\n",
"Your data comes in many shapes; your tensors should too. *Ragged tensors* are the TensorFlow equivalent of nested variable-length lists. They make it easy to store and process data with non-uniform shapes, including:\n",
"\n",
"- Variable-length features, such as the set of actors in a movie.\n",
"- Batches of variable-length sequential inputs, such as sentences or video clips.\n",
"- Hierarchical inputs, such as text documents that are subdivided into sections, paragraphs, sentences, and words.\n",
"- Individual fields in structured inputs, such as protocol buffers.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1mhU_qY3_mla"
},
"source": [
"### What you can do with a ragged tensor\n",
"\n",
"Ragged tensors are supported by more than a hundred TensorFlow operations, including math operations (such as `tf.add` and `tf.reduce_mean`), array operations (such as `tf.concat` and `tf.tile`), string manipulation ops (such as `tf.strings.substr`), control flow operations (such as `tf.while_loop` and `tf.map_fn`), and many others:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:42.619233Z",
"iopub.status.busy": "2024-08-27T01:20:42.618569Z",
"iopub.status.idle": "2024-08-27T01:20:45.447962Z",
"shell.execute_reply": "2024-08-27T01:20:45.447223Z"
},
"id": "vGmJGSf_-PVB"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"I0000 00:00:1724721643.161122 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721643.164924 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721643.168750 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721643.172328 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721643.183636 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721643.187178 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721643.190668 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721643.194028 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721643.197422 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721643.201089 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721643.204485 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721643.207934 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.469725 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.471772 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.473847 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.475930 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.477949 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.479867 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.481845 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.483831 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.486309 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.488220 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.490205 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.492181 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.532055 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.534019 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.536012 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.538150 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.540140 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.542032 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.543963 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.545991 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.547989 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.551435 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.553772 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1724721644.556197 10064 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"
\n",
"tf.Tensor([2.25 nan 5.33333333 6. nan], shape=(5,), dtype=float64)\n",
"\n",
"\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"digits = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])\n",
"words = tf.ragged.constant([[\"So\", \"long\"], [\"thanks\", \"for\", \"all\", \"the\", \"fish\"]])\n",
"print(tf.add(digits, 3))\n",
"print(tf.reduce_mean(digits, axis=1))\n",
"print(tf.concat([digits, [[5, 3]]], axis=0))\n",
"print(tf.tile(digits, [1, 2]))\n",
"print(tf.strings.substr(words, 0, 2))\n",
"print(tf.map_fn(tf.math.square, digits))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Pt-5OIc8-PVG"
},
"source": [
"There are also a number of methods and operations that are\n",
"specific to ragged tensors, including factory methods, conversion methods,\n",
"and value-mapping operations.\n",
"For a list of supported ops, see the **`tf.ragged` package\n",
"documentation**."
]
},
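{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, the conversion methods round-trip between a `RaggedTensor` and other representations. The following is a small illustrative sketch (the variable name `rt` and the sample values are arbitrary):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rt = tf.ragged.constant([[1, 2], [3], []])\n",
"print(rt.to_list())     # Nested Python list: [[1, 2], [3], []]\n",
"print(rt.to_tensor(0))  # Dense tensor, padding short rows with 0.\n",
"print(rt.to_sparse())   # Equivalent tf.sparse.SparseTensor.\n",
"print(tf.RaggedTensor.from_tensor(rt.to_tensor(0), padding=0))  # Back to ragged."
]
},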
{
"cell_type": "markdown",
"metadata": {
"id": "r8fjGgf3B_6z"
},
"source": [
"Ragged tensors are supported by many TensorFlow APIs, including [Keras](https://www.tensorflow.org/guide/keras), [Datasets](https://www.tensorflow.org/guide/data), [tf.function](https://www.tensorflow.org/guide/function), [SavedModels](https://www.tensorflow.org/guide/saved_model), and [tf.Example](https://www.tensorflow.org/tutorials/load_data/tfrecord). For more information, check the section on **TensorFlow APIs** below."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aTXLjQlcHP8a"
},
"source": [
"As with normal tensors, you can use Python-style indexing to access specific slices of a ragged tensor. For more information, refer to the section on **Indexing** below."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:45.451689Z",
"iopub.status.busy": "2024-08-27T01:20:45.451341Z",
"iopub.status.idle": "2024-08-27T01:20:45.459680Z",
"shell.execute_reply": "2024-08-27T01:20:45.458991Z"
},
"id": "n8YMKXpI-PVH"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([3 1 4 1], shape=(4,), dtype=int32)\n"
]
}
],
"source": [
"print(digits[0]) # First row"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:45.462918Z",
"iopub.status.busy": "2024-08-27T01:20:45.462481Z",
"iopub.status.idle": "2024-08-27T01:20:45.790498Z",
"shell.execute_reply": "2024-08-27T01:20:45.789793Z"
},
"id": "Awi8i9q5_DuX"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(digits[:, :2]) # First two values in each row."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:45.793941Z",
"iopub.status.busy": "2024-08-27T01:20:45.793671Z",
"iopub.status.idle": "2024-08-27T01:20:46.116769Z",
"shell.execute_reply": "2024-08-27T01:20:46.116013Z"
},
"id": "sXgQtTcgHHMR"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(digits[:, -2:]) # Last two values in each row."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6FU5T_-8-PVK"
},
"source": [
"And just like normal tensors, you can use Python arithmetic and comparison operators to perform elementwise operations. For more information, check the section on **Overloaded operators** below."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.120223Z",
"iopub.status.busy": "2024-08-27T01:20:46.119966Z",
"iopub.status.idle": "2024-08-27T01:20:46.124141Z",
"shell.execute_reply": "2024-08-27T01:20:46.123517Z"
},
"id": "2tdUEtb7-PVL"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(digits + 3)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.127030Z",
"iopub.status.busy": "2024-08-27T01:20:46.126779Z",
"iopub.status.idle": "2024-08-27T01:20:46.156213Z",
"shell.execute_reply": "2024-08-27T01:20:46.155549Z"
},
"id": "X-bxG0nc_Nmf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(digits + tf.ragged.constant([[1, 2, 3, 4], [], [5, 6, 7], [8], []]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2tsw8mN0ESIT"
},
"source": [
"If you need to perform an elementwise transformation to the values of a `RaggedTensor`, you can use `tf.ragged.map_flat_values`, which takes a function plus one or more arguments, and applies the function to transform the `RaggedTensor`'s values."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.159269Z",
"iopub.status.busy": "2024-08-27T01:20:46.159039Z",
"iopub.status.idle": "2024-08-27T01:20:46.163468Z",
"shell.execute_reply": "2024-08-27T01:20:46.162866Z"
},
"id": "pvt5URbdEt-D"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"times_two_plus_one = lambda x: x * 2 + 1\n",
"print(tf.ragged.map_flat_values(times_two_plus_one, digits))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HNxF6_QKAzkl"
},
"source": [
"Ragged tensors can be converted to nested Python `list`s and NumPy `array`s:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.166527Z",
"iopub.status.busy": "2024-08-27T01:20:46.166288Z",
"iopub.status.idle": "2024-08-27T01:20:46.172822Z",
"shell.execute_reply": "2024-08-27T01:20:46.172260Z"
},
"id": "A5NHb8ViA9dt"
},
"outputs": [
{
"data": {
"text/plain": [
"[[3, 1, 4, 1], [], [5, 9, 2], [6], []]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"digits.to_list()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.176012Z",
"iopub.status.busy": "2024-08-27T01:20:46.175713Z",
"iopub.status.idle": "2024-08-27T01:20:46.180140Z",
"shell.execute_reply": "2024-08-27T01:20:46.179546Z"
},
"id": "2o1wogVyA6Yp"
},
"outputs": [
{
"data": {
"text/plain": [
"array([array([3, 1, 4, 1], dtype=int32), array([], dtype=int32),\n",
" array([5, 9, 2], dtype=int32), array([6], dtype=int32),\n",
" array([], dtype=int32)], dtype=object)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"digits.numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7M5RHOgp-PVN"
},
"source": [
"### Constructing a ragged tensor\n",
"\n",
"The simplest way to construct a ragged tensor is using `tf.ragged.constant`, which builds the `RaggedTensor` corresponding to a given nested Python `list` or NumPy `array`:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.183298Z",
"iopub.status.busy": "2024-08-27T01:20:46.182912Z",
"iopub.status.idle": "2024-08-27T01:20:46.187860Z",
"shell.execute_reply": "2024-08-27T01:20:46.187252Z"
},
"id": "yhgKMozw-PVP"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"sentences = tf.ragged.constant([\n",
" [\"Let's\", \"build\", \"some\", \"ragged\", \"tensors\", \"!\"],\n",
" [\"We\", \"can\", \"use\", \"tf.ragged.constant\", \".\"]])\n",
"print(sentences)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.190814Z",
"iopub.status.busy": "2024-08-27T01:20:46.190587Z",
"iopub.status.idle": "2024-08-27T01:20:46.195872Z",
"shell.execute_reply": "2024-08-27T01:20:46.195189Z"
},
"id": "TW1g7eE2ee8M"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"paragraphs = tf.ragged.constant([\n",
" [['I', 'have', 'a', 'cat'], ['His', 'name', 'is', 'Mat']],\n",
" [['Do', 'you', 'want', 'to', 'come', 'visit'], [\"I'm\", 'free', 'tomorrow']],\n",
"])\n",
"print(paragraphs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SPLn5xHn-PVR"
},
"source": [
"Ragged tensors can also be constructed by pairing flat *values* tensors with *row-partitioning* tensors indicating how those values should be divided into rows, using factory classmethods such as `tf.RaggedTensor.from_value_rowids`, `tf.RaggedTensor.from_row_lengths`, and `tf.RaggedTensor.from_row_splits`.\n",
"\n",
"#### `tf.RaggedTensor.from_value_rowids`\n",
"\n",
"If you know which row each value belongs to, then you can build a `RaggedTensor` using a `value_rowids` row-partitioning tensor:\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.198953Z",
"iopub.status.busy": "2024-08-27T01:20:46.198727Z",
"iopub.status.idle": "2024-08-27T01:20:46.216440Z",
"shell.execute_reply": "2024-08-27T01:20:46.215806Z"
},
"id": "SEvcPUcl-PVS"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(tf.RaggedTensor.from_value_rowids(\n",
" values=[3, 1, 4, 1, 5, 9, 2],\n",
" value_rowids=[0, 0, 0, 0, 2, 2, 3]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RBQh8sYc-PVV"
},
"source": [
"#### `tf.RaggedTensor.from_row_lengths`\n",
"\n",
"If you know how long each row is, then you can use a `row_lengths` row-partitioning tensor:\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.219589Z",
"iopub.status.busy": "2024-08-27T01:20:46.219268Z",
"iopub.status.idle": "2024-08-27T01:20:46.227093Z",
"shell.execute_reply": "2024-08-27T01:20:46.226462Z"
},
"id": "LBY81WXl-PVW"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(tf.RaggedTensor.from_row_lengths(\n",
" values=[3, 1, 4, 1, 5, 9, 2],\n",
" row_lengths=[4, 0, 2, 1]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8p5V8_Iu-PVa"
},
"source": [
"#### `tf.RaggedTensor.from_row_splits`\n",
"\n",
"If you know the index where each row starts and ends, then you can use a `row_splits` row-partitioning tensor:\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.230120Z",
"iopub.status.busy": "2024-08-27T01:20:46.229854Z",
"iopub.status.idle": "2024-08-27T01:20:46.238587Z",
"shell.execute_reply": "2024-08-27T01:20:46.237973Z"
},
"id": "FwizuqZI-PVb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(tf.RaggedTensor.from_row_splits(\n",
" values=[3, 1, 4, 1, 5, 9, 2],\n",
" row_splits=[0, 4, 4, 6, 7]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E-9imo8DhwuA"
},
"source": [
"See the `tf.RaggedTensor` class documentation for a full list of factory methods.\n",
"\n",
"Note: By default, these factory methods add assertions that the row partition tensor is well-formed and consistent with the number of values. The `validate=False` parameter can be used to skip these checks if you can guarantee that the inputs are well-formed and consistent."
]
},
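{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance, if a `row_splits` tensor is already known to be well-formed (say, it was produced by another ragged op), the assertions can be skipped. This is a small sketch of that pattern, reusing the values from the example above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"values = tf.constant([3, 1, 4, 1, 5, 9, 2])\n",
"row_splits = tf.constant([0, 4, 4, 6, 7], dtype=tf.int64)\n",
"# Skip the well-formedness checks because row_splits is known to be valid.\n",
"print(tf.RaggedTensor.from_row_splits(values, row_splits, validate=False))"
]
},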
{
"cell_type": "markdown",
"metadata": {
"id": "YQAOsT1_-PVg"
},
"source": [
"### What you can store in a ragged tensor\n",
"\n",
"As with normal `Tensor`s, the values in a `RaggedTensor` must all have the same\n",
"type; and the values must all be at the same nesting depth (the *rank* of the\n",
"tensor):"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.241557Z",
"iopub.status.busy": "2024-08-27T01:20:46.241318Z",
"iopub.status.idle": "2024-08-27T01:20:46.245967Z",
"shell.execute_reply": "2024-08-27T01:20:46.245305Z"
},
"id": "SqbPBd_w-PVi"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(tf.ragged.constant([[\"Hi\"], [\"How\", \"are\", \"you\"]])) # ok: type=string, rank=2"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.248931Z",
"iopub.status.busy": "2024-08-27T01:20:46.248691Z",
"iopub.status.idle": "2024-08-27T01:20:46.253486Z",
"shell.execute_reply": "2024-08-27T01:20:46.252915Z"
},
"id": "83ZCSJnQAWAf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(tf.ragged.constant([[[1, 2], [3]], [[4, 5]]])) # ok: type=int32, rank=3"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.256304Z",
"iopub.status.busy": "2024-08-27T01:20:46.256032Z",
"iopub.status.idle": "2024-08-27T01:20:46.259922Z",
"shell.execute_reply": "2024-08-27T01:20:46.259376Z"
},
"id": "ewA3cISdDfmP"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Can't convert Python sequence with mixed types to Tensor.\n"
]
}
],
"source": [
"try:\n",
" tf.ragged.constant([[\"one\", \"two\"], [3, 4]]) # bad: multiple types\n",
"except ValueError as exception:\n",
" print(exception)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.263082Z",
"iopub.status.busy": "2024-08-27T01:20:46.262557Z",
"iopub.status.idle": "2024-08-27T01:20:46.266272Z",
"shell.execute_reply": "2024-08-27T01:20:46.265620Z"
},
"id": "EOWIlVidDl-n"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"all scalar values must have the same nesting depth\n"
]
}
],
"source": [
"try:\n",
" tf.ragged.constant([\"A\", [\"B\", \"C\"]]) # bad: multiple nesting depths\n",
"except ValueError as exception:\n",
" print(exception)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nhHMFhSp-PVq"
},
"source": [
"## Example use case\n",
"\n",
"The following example demonstrates how `RaggedTensor`s can be used to construct and combine unigram and bigram embeddings for a batch of variable-length queries, using special markers for the beginning and end of each sentence. For more details on the ops used in this example, check the `tf.ragged` package documentation."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.269513Z",
"iopub.status.busy": "2024-08-27T01:20:46.269300Z",
"iopub.status.idle": "2024-08-27T01:20:46.385369Z",
"shell.execute_reply": "2024-08-27T01:20:46.384730Z"
},
"id": "ZBs_V7e--PVr"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[-0.07189021 0.23444025 -0.04020268 -0.27113184]\n",
" [ 0.10560822 0.00976487 -0.17885399 0.5371701 ]\n",
" [-0.23479678 -0.15996003 0.07078557 0.24388357]], shape=(3, 4), dtype=float32)\n"
]
}
],
"source": [
"queries = tf.ragged.constant([['Who', 'is', 'Dan', 'Smith'],\n",
" ['Pause'],\n",
" ['Will', 'it', 'rain', 'later', 'today']])\n",
"\n",
"# Create an embedding table.\n",
"num_buckets = 1024\n",
"embedding_size = 4\n",
"embedding_table = tf.Variable(\n",
" tf.random.truncated_normal([num_buckets, embedding_size],\n",
" stddev=1.0 / math.sqrt(embedding_size)))\n",
"\n",
"# Look up the embedding for each word.\n",
"word_buckets = tf.strings.to_hash_bucket_fast(queries, num_buckets)\n",
"word_embeddings = tf.nn.embedding_lookup(embedding_table, word_buckets) # ①\n",
"\n",
"# Add markers to the beginning and end of each sentence.\n",
"marker = tf.fill([queries.nrows(), 1], '#')\n",
"padded = tf.concat([marker, queries, marker], axis=1) # ②\n",
"\n",
"# Build word bigrams and look up embeddings.\n",
"bigrams = tf.strings.join([padded[:, :-1], padded[:, 1:]], separator='+') # ③\n",
"\n",
"bigram_buckets = tf.strings.to_hash_bucket_fast(bigrams, num_buckets)\n",
"bigram_embeddings = tf.nn.embedding_lookup(embedding_table, bigram_buckets) # ④\n",
"\n",
"# Find the average embedding for each sentence\n",
"all_embeddings = tf.concat([word_embeddings, bigram_embeddings], axis=1) # ⑤\n",
"avg_embedding = tf.reduce_mean(all_embeddings, axis=1) # ⑥\n",
"print(avg_embedding)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y_lE_LAVcWQH"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "An_k0pX1-PVt"
},
"source": [
"## Ragged and uniform dimensions\n",
"\n",
"A ***ragged dimension*** is a dimension whose slices may have different lengths. For example, the inner (column) dimension of `rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged, since the column slices (`rt[0, :]`, ..., `rt[4, :]`) have different lengths. Dimensions whose slices all have the same length are called *uniform dimensions*.\n",
"\n",
"The outermost dimension of a ragged tensor is always uniform, since it consists of a single slice (and, therefore, there is no possibility for differing slice lengths). The remaining dimensions may be either ragged or uniform. For example, you may store the word embeddings for each word in a batch of sentences using a ragged tensor with shape `[num_sentences, (num_words), embedding_size]`, where the parentheses around `(num_words)` indicate that the dimension is ragged.\n",
"\n",
"\n",
"\n",
"Ragged tensors may have multiple ragged dimensions. For example, you could store a batch of structured text documents using a tensor with shape `[num_documents, (num_paragraphs), (num_sentences), (num_words)]` (where again parentheses are used to indicate ragged dimensions).\n",
"\n",
"As with `tf.Tensor`, the ***rank*** of a ragged tensor is its total number of dimensions (including both ragged and uniform dimensions). A ***potentially ragged tensor*** is a value that might be either a `tf.Tensor` or a `tf.RaggedTensor`.\n",
"\n",
"When describing the shape of a RaggedTensor, ragged dimensions are conventionally indicated by enclosing them in parentheses. For example, as you saw above, the shape of a 3D RaggedTensor that stores word embeddings for each word in a batch of sentences can be written as `[num_sentences, (num_words), embedding_size]`.\n",
"\n",
"The `RaggedTensor.shape` attribute returns a `tf.TensorShape` for a ragged tensor where ragged dimensions have size `None`:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.388782Z",
"iopub.status.busy": "2024-08-27T01:20:46.388535Z",
"iopub.status.idle": "2024-08-27T01:20:46.393769Z",
"shell.execute_reply": "2024-08-27T01:20:46.393185Z"
},
"id": "M2Wzx4JEIvmb"
},
"outputs": [
{
"data": {
"text/plain": [
"TensorShape([2, None])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.ragged.constant([[\"Hi\"], [\"How\", \"are\", \"you\"]]).shape"
]
},
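{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same holds for tensors with more than one ragged dimension. In the sketch below (the nested constant is just an illustrative example), both inner dimensions are reported as `None`, and `ragged_rank` gives the number of ragged dimensions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"docs = tf.ragged.constant([[['Hello', 'world'], ['!']],\n",
"                           [['Nice', 'to', 'meet', 'you']]])\n",
"print(docs.shape)        # (2, None, None)\n",
"print(docs.ragged_rank)  # 2 ragged dimensions"
]
},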
{
"cell_type": "markdown",
"metadata": {
"id": "G9tfJOeFlijE"
},
"source": [
"The method `tf.RaggedTensor.bounding_shape` can be used to find a tight\n",
"bounding shape for a given `RaggedTensor`:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.396936Z",
"iopub.status.busy": "2024-08-27T01:20:46.396510Z",
"iopub.status.idle": "2024-08-27T01:20:46.404249Z",
"shell.execute_reply": "2024-08-27T01:20:46.403666Z"
},
"id": "5DHaqXHxlWi0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([2 3], shape=(2,), dtype=int64)\n"
]
}
],
"source": [
"print(tf.ragged.constant([[\"Hi\"], [\"How\", \"are\", \"you\"]]).bounding_shape())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V8e7x95UcLS6"
},
"source": [
"## Ragged vs sparse\n",
"\n",
"A ragged tensor should *not* be thought of as a type of sparse tensor. In particular, sparse tensors are *efficient encodings for `tf.Tensor`* that model the same data in a compact format; but ragged tensor is an *extension to `tf.Tensor`* that models an expanded class of data. This difference is crucial when defining operations:\n",
"\n",
"- Applying an op to a sparse or dense tensor should always give the same result.\n",
"- Applying an op to a ragged or sparse tensor may give different results.\n",
"\n",
"As an illustrative example, consider how array operations such as `concat`, `stack`, and `tile` are defined for ragged vs. sparse tensors. Concatenating ragged tensors joins each row to form a single row with the combined length:\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.407323Z",
"iopub.status.busy": "2024-08-27T01:20:46.407069Z",
"iopub.status.idle": "2024-08-27T01:20:46.417971Z",
"shell.execute_reply": "2024-08-27T01:20:46.417386Z"
},
"id": "ush7IGUWLXIn"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"ragged_x = tf.ragged.constant([[\"John\"], [\"a\", \"big\", \"dog\"], [\"my\", \"cat\"]])\n",
"ragged_y = tf.ragged.constant([[\"fell\", \"asleep\"], [\"barked\"], [\"is\", \"fuzzy\"]])\n",
"print(tf.concat([ragged_x, ragged_y], axis=1))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pvQzZG8zMoWa"
},
"source": [
"However, concatenating sparse tensors is equivalent to concatenating the corresponding dense tensors, as illustrated by the following example (where Ø indicates missing values):\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.420899Z",
"iopub.status.busy": "2024-08-27T01:20:46.420684Z",
"iopub.status.idle": "2024-08-27T01:20:46.428992Z",
"shell.execute_reply": "2024-08-27T01:20:46.428429Z"
},
"id": "eTIhGayQL0gI"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[b'John' b'' b'' b'fell' b'asleep']\n",
" [b'a' b'big' b'dog' b'barked' b'']\n",
" [b'my' b'cat' b'' b'is' b'fuzzy']], shape=(3, 5), dtype=string)\n"
]
}
],
"source": [
"sparse_x = ragged_x.to_sparse()\n",
"sparse_y = ragged_y.to_sparse()\n",
"sparse_result = tf.sparse.concat(sp_inputs=[sparse_x, sparse_y], axis=1)\n",
"print(tf.sparse.to_dense(sparse_result, ''))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Vl8eQN8pMuYx"
},
"source": [
"For another example of why this distinction is important, consider the\n",
"definition of “the mean value of each row” for an op such as `tf.reduce_mean`.\n",
"For a ragged tensor, the mean value for a row is the sum of the\n",
"row’s values divided by the row’s width.\n",
"But for a sparse tensor, the mean value for a row is the sum of the\n",
"row’s values divided by the sparse tensor’s overall width (which is\n",
"greater than or equal to the width of the longest row).\n"
]
},
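{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sketch below makes that concrete with made-up numbers: the per-row mean of a ragged tensor divides by each row's own length, while padding the same data out to a dense tensor (which is how a sparse tensor treats the missing values) divides by the overall width:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rt = tf.ragged.constant([[1., 2.], [3.], [4., 5., 6.]])\n",
"# Ragged: each row's sum is divided by that row's own length -> [1.5, 3.0, 5.0]\n",
"print(tf.reduce_mean(rt, axis=1))\n",
"# Zero-padded dense: each row's sum is divided by the overall width (3) -> [1.0, 1.0, 5.0]\n",
"print(tf.reduce_mean(rt.to_tensor(0.), axis=1))"
]
},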
{
"cell_type": "markdown",
"metadata": {
"id": "u4yjxcK7IPXc"
},
"source": [
"## TensorFlow APIs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VoZGwFQjIYU5"
},
"source": [
"### Keras\n",
"\n",
"[tf.keras](https://www.tensorflow.org/guide/keras) is TensorFlow's high-level API for building and training deep learning models. It doesn't have ragged support. But it does support masked tensors. So the easiest way to use a ragged tensor in a Keras model is to convert the ragged tensor to a dense tensor, using `.to_tensor()` and then using Keras's builtin masking:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.432250Z",
"iopub.status.busy": "2024-08-27T01:20:46.431996Z",
"iopub.status.idle": "2024-08-27T01:20:46.436141Z",
"shell.execute_reply": "2024-08-27T01:20:46.435483Z"
},
"id": "ucYf2sSzTvQo"
},
"outputs": [],
"source": [
"# Task: predict whether each sentence is a question or not.\n",
"sentences = tf.constant(\n",
" ['What makes you think she is a witch?',\n",
" 'She turned me into a newt.',\n",
" 'A newt?',\n",
" 'Well, I got better.'])\n",
"is_question = tf.constant([True, False, True, False])"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.439376Z",
"iopub.status.busy": "2024-08-27T01:20:46.438811Z",
"iopub.status.idle": "2024-08-27T01:20:46.450725Z",
"shell.execute_reply": "2024-08-27T01:20:46.450175Z"
},
"id": "MGYKmizJTw8B"
},
"outputs": [
{
"data": {
"text/plain": [
"[[940, 203, 668, 387, 790, 320, 939, 185],\n",
" [315, 515, 791, 181, 939, 787],\n",
" [564, 205],\n",
" [820, 180, 993, 739]]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preprocess the input strings.\n",
"hash_buckets = 1000\n",
"words = tf.strings.split(sentences, ' ')\n",
"hashed_words = tf.strings.to_hash_bucket_fast(words, hash_buckets)\n",
"hashed_words.to_list()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.453980Z",
"iopub.status.busy": "2024-08-27T01:20:46.453405Z",
"iopub.status.idle": "2024-08-27T01:20:46.459940Z",
"shell.execute_reply": "2024-08-27T01:20:46.459397Z"
},
"id": "7FTujwOlUT8J"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hashed_words.to_tensor()"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.463107Z",
"iopub.status.busy": "2024-08-27T01:20:46.462584Z",
"iopub.status.idle": "2024-08-27T01:20:46.507832Z",
"shell.execute_reply": "2024-08-27T01:20:46.507130Z"
},
"id": "vzWudaESUBOZ"
},
"outputs": [],
"source": [
"tf.keras.Input?"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:46.511233Z",
"iopub.status.busy": "2024-08-27T01:20:46.510674Z",
"iopub.status.idle": "2024-08-27T01:20:50.837694Z",
"shell.execute_reply": "2024-08-27T01:20:50.837003Z"
},
"id": "pHls7hQVJlk5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"I0000 00:00:1724721648.311913 10235 service.cc:146] XLA service 0x7f82d4008680 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n",
"I0000 00:00:1724721648.311945 10235 service.cc:154] StreamExecutor device (0): Tesla T4, Compute Capability 7.5\n",
"I0000 00:00:1724721648.311949 10235 service.cc:154] StreamExecutor device (1): Tesla T4, Compute Capability 7.5\n",
"I0000 00:00:1724721648.311952 10235 service.cc:154] StreamExecutor device (2): Tesla T4, Compute Capability 7.5\n",
"I0000 00:00:1724721648.311955 10235 service.cc:154] StreamExecutor device (3): Tesla T4, Compute Capability 7.5\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4s/step - loss: 8.0590"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 4s/step - loss: 8.0590\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/5\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 45ms/step - loss: 8.0590"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 47ms/step - loss: 8.0590\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/5\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 42ms/step - loss: 8.0590"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 44ms/step - loss: 8.0590\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/5\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 42ms/step - loss: 8.0590"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 44ms/step - loss: 8.0590\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"I0000 00:00:1724721650.626376 10235 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 41ms/step - loss: 8.0590"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 43ms/step - loss: 8.0590\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Build the Keras model.\n",
"keras_model = tf.keras.Sequential([\n",
" tf.keras.layers.Embedding(hash_buckets, 16, mask_zero=True),\n",
" tf.keras.layers.LSTM(32, return_sequences=True, use_bias=False),\n",
" tf.keras.layers.GlobalAveragePooling1D(),\n",
" tf.keras.layers.Dense(32),\n",
" tf.keras.layers.Activation(tf.nn.relu),\n",
" tf.keras.layers.Dense(1)\n",
"])\n",
"\n",
"keras_model.compile(loss='binary_crossentropy', optimizer='rmsprop')\n",
"keras_model.fit(hashed_words.to_tensor(), is_question, epochs=5)\n"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:50.841068Z",
"iopub.status.busy": "2024-08-27T01:20:50.840631Z",
"iopub.status.idle": "2024-08-27T01:20:51.290503Z",
"shell.execute_reply": "2024-08-27T01:20:51.289852Z"
},
"id": "1IAjjmdTU9OU"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 406ms/step"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 407ms/step\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-0.00517703]\n",
" [-0.00227403]\n",
" [-0.00706224]\n",
" [-0.00354813]]\n"
]
}
],
"source": [
"print(keras_model.predict(hashed_words.to_tensor()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8B_sdlt6Ij61"
},
"source": [
"### tf.Example\n",
"\n",
"[tf.Example](https://www.tensorflow.org/tutorials/load_data/tfrecord) is a standard [protobuf](https://developers.google.com/protocol-buffers/) encoding for TensorFlow data. Data encoded with `tf.Example`s often includes variable-length features. For example, the following code defines a batch of four `tf.Example` messages with different feature lengths:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:51.294346Z",
"iopub.status.busy": "2024-08-27T01:20:51.293713Z",
"iopub.status.idle": "2024-08-27T01:20:51.300081Z",
"shell.execute_reply": "2024-08-27T01:20:51.299374Z"
},
"id": "xsiglYM7TXGr"
},
"outputs": [],
"source": [
"import google.protobuf.text_format as pbtext\n",
"\n",
"def build_tf_example(s):\n",
" return pbtext.Merge(s, tf.train.Example()).SerializeToString()\n",
"\n",
"example_batch = [\n",
" build_tf_example(r'''\n",
" features {\n",
" feature {key: \"colors\" value {bytes_list {value: [\"red\", \"blue\"]} } }\n",
" feature {key: \"lengths\" value {int64_list {value: [7]} } } }'''),\n",
" build_tf_example(r'''\n",
" features {\n",
" feature {key: \"colors\" value {bytes_list {value: [\"orange\"]} } }\n",
" feature {key: \"lengths\" value {int64_list {value: []} } } }'''),\n",
" build_tf_example(r'''\n",
" features {\n",
" feature {key: \"colors\" value {bytes_list {value: [\"black\", \"yellow\"]} } }\n",
" feature {key: \"lengths\" value {int64_list {value: [1, 3]} } } }'''),\n",
" build_tf_example(r'''\n",
" features {\n",
" feature {key: \"colors\" value {bytes_list {value: [\"green\"]} } }\n",
" feature {key: \"lengths\" value {int64_list {value: [3, 5, 2]} } } }''')]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "szUuXFvtUL2o"
},
"source": [
"You can parse this encoded data using `tf.io.parse_example`, which takes a tensor of serialized strings and a feature specification dictionary, and returns a dictionary mapping feature names to tensors. To read the variable-length features into ragged tensors, you simply use `tf.io.RaggedFeature` in the feature specification dictionary:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:51.303436Z",
"iopub.status.busy": "2024-08-27T01:20:51.302916Z",
"iopub.status.idle": "2024-08-27T01:20:51.310383Z",
"shell.execute_reply": "2024-08-27T01:20:51.309799Z"
},
"id": "xcdaIbYVT4mo"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"colors=\n",
"lengths=\n"
]
}
],
"source": [
"feature_specification = {\n",
" 'colors': tf.io.RaggedFeature(tf.string),\n",
" 'lengths': tf.io.RaggedFeature(tf.int64),\n",
"}\n",
"feature_tensors = tf.io.parse_example(example_batch, feature_specification)\n",
"for name, value in feature_tensors.items():\n",
" print(\"{}={}\".format(name, value))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IK9X_8rXVr8h"
},
"source": [
"`tf.io.RaggedFeature` can also be used to read features with multiple ragged dimensions. For details, refer to the [API documentation](https://www.tensorflow.org/api_docs/python/tf/io/RaggedFeature)."
]
},
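{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch (the feature keys `vals` and `inner_lengths` here are hypothetical, not part of the example batch above), a feature with an extra ragged dimension could be described by naming a partition key in the `partitions` argument:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical spec: 'vals' holds the flat int64 values, and 'inner_lengths'\n",
"# holds the row lengths that partition them into an extra ragged dimension.\n",
"nested_spec = {\n",
"    'vals': tf.io.RaggedFeature(\n",
"        dtype=tf.int64,\n",
"        value_key='vals',\n",
"        partitions=[tf.io.RaggedFeature.RowLengths('inner_lengths')]),\n",
"}"
]
},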
{
"cell_type": "markdown",
"metadata": {
"id": "UJowRhlxIX0R"
},
"source": [
"### Datasets\n",
"\n",
"[tf.data](https://www.tensorflow.org/guide/data) is an API that enables you to build complex input pipelines from simple, reusable pieces. Its core data structure is `tf.data.Dataset`, which represents a sequence of elements, in which each element consists of one or more components."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:51.314004Z",
"iopub.status.busy": "2024-08-27T01:20:51.313475Z",
"iopub.status.idle": "2024-08-27T01:20:51.317520Z",
"shell.execute_reply": "2024-08-27T01:20:51.316919Z"
},
"id": "fBml1m2G2vO9"
},
"outputs": [],
"source": [
"# Helper function used to print datasets in the examples below.\n",
"def print_dictionary_dataset(dataset):\n",
" for i, element in enumerate(dataset):\n",
" print(\"Element {}:\".format(i))\n",
" for (feature_name, feature_value) in element.items():\n",
" print('{:>14} = {}'.format(feature_name, feature_value))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gEu_H1Sp2jz1"
},
"source": [
"#### Building Datasets with ragged tensors\n",
"\n",
"Datasets can be built from ragged tensors using the same methods that are used to build them from `tf.Tensor`s or NumPy `array`s, such as `Dataset.from_tensor_slices`:"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:51.320920Z",
"iopub.status.busy": "2024-08-27T01:20:51.320405Z",
"iopub.status.idle": "2024-08-27T01:20:51.338582Z",
"shell.execute_reply": "2024-08-27T01:20:51.337946Z"
},
"id": "BuelF_y2mEq9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Element 0:\n",
" colors = [b'red' b'blue']\n",
" lengths = [7]\n",
"Element 1:\n",
" colors = [b'orange']\n",
" lengths = []\n",
"Element 2:\n",
" colors = [b'black' b'yellow']\n",
" lengths = [1 3]\n",
"Element 3:\n",
" colors = [b'green']\n",
" lengths = [3 5 2]\n"
]
}
],
"source": [
"dataset = tf.data.Dataset.from_tensor_slices(feature_tensors)\n",
"print_dictionary_dataset(dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mC-QNkJc56De"
},
"source": [
"Note: `Dataset.from_generator` does not support ragged tensors yet, but support will be added soon."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K0UKvBLf1VMu"
},
"source": [
"#### Batching and unbatching Datasets with ragged tensors\n",
"\n",
"Datasets with ragged tensors can be batched (which combines *n* consecutive elements into a single elements) using the `Dataset.batch` method."
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:51.342344Z",
"iopub.status.busy": "2024-08-27T01:20:51.341764Z",
"iopub.status.idle": "2024-08-27T01:20:51.356737Z",
"shell.execute_reply": "2024-08-27T01:20:51.356051Z"
},
"id": "lk62aRz63IZn"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Element 0:\n",
" colors = \n",
" lengths = \n",
"Element 1:\n",
" colors = \n",
" lengths = \n"
]
}
],
"source": [
"batched_dataset = dataset.batch(2)\n",
"print_dictionary_dataset(batched_dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NLSGiYEQ5A8N"
},
"source": [
"Conversely, a batched dataset can be transformed into a flat dataset using `Dataset.unbatch`."
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:51.360162Z",
"iopub.status.busy": "2024-08-27T01:20:51.359538Z",
"iopub.status.idle": "2024-08-27T01:20:51.399944Z",
"shell.execute_reply": "2024-08-27T01:20:51.399254Z"
},
"id": "CxLlaPw_5Je4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Element 0:\n",
" colors = [b'red' b'blue']\n",
" lengths = [7]\n",
"Element 1:\n",
" colors = [b'orange']\n",
" lengths = []\n",
"Element 2:\n",
" colors = [b'black' b'yellow']\n",
" lengths = [1 3]\n",
"Element 3:\n",
" colors = [b'green']\n",
" lengths = [3 5 2]\n"
]
}
],
"source": [
"unbatched_dataset = batched_dataset.unbatch()\n",
"print_dictionary_dataset(unbatched_dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YzpLQFh33q0N"
},
"source": [
"#### Batching Datasets with variable-length non-ragged tensors\n",
"\n",
"If you have a Dataset that contains non-ragged tensors, and tensor lengths vary across elements, then you can batch those non-ragged tensors into ragged tensors by applying the `dense_to_ragged_batch` transformation:"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:51.403229Z",
"iopub.status.busy": "2024-08-27T01:20:51.402817Z",
"iopub.status.idle": "2024-08-27T01:20:51.455934Z",
"shell.execute_reply": "2024-08-27T01:20:51.455180Z"
},
"id": "PYnhERwh3_mf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /tmpfs/tmp/ipykernel_10064/1427668168.py:4: dense_to_ragged_batch (from tensorflow.python.data.experimental.ops.batching) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.data.Dataset.ragged_batch` instead.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n"
]
}
],
"source": [
"non_ragged_dataset = tf.data.Dataset.from_tensor_slices([1, 5, 3, 2, 8])\n",
"non_ragged_dataset = non_ragged_dataset.map(tf.range)\n",
"batched_non_ragged_dataset = non_ragged_dataset.apply(\n",
" tf.data.experimental.dense_to_ragged_batch(2))\n",
"for element in batched_non_ragged_dataset:\n",
" print(element)"
]
},
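{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the deprecation warning above indicates, newer TensorFlow releases expose this transformation directly as `Dataset.ragged_batch`. A short sketch of the equivalent call:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Equivalent to the dense_to_ragged_batch example above, using the\n",
"# Dataset.ragged_batch method recommended by the deprecation warning.\n",
"for element in non_ragged_dataset.ragged_batch(2):\n",
"  print(element)"
]
},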
{
"cell_type": "markdown",
"metadata": {
"id": "nXFPeE-CzJ-s"
},
"source": [
"#### Transforming Datasets with ragged tensors\n",
"\n",
"You can also create or transform ragged tensors in Datasets using `Dataset.map`:"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:51.459050Z",
"iopub.status.busy": "2024-08-27T01:20:51.458812Z",
"iopub.status.idle": "2024-08-27T01:20:51.528119Z",
"shell.execute_reply": "2024-08-27T01:20:51.527442Z"
},
"id": "Ios1GuG-pf9U"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Element 0:\n",
" mean_length = 7\n",
" length_ranges = \n",
"Element 1:\n",
" mean_length = 0\n",
" length_ranges = \n",
"Element 2:\n",
" mean_length = 2\n",
" length_ranges = \n",
"Element 3:\n",
" mean_length = 3\n",
" length_ranges = \n"
]
}
],
"source": [
"def transform_lengths(features):\n",
" return {\n",
" 'mean_length': tf.math.reduce_mean(features['lengths']),\n",
" 'length_ranges': tf.ragged.range(features['lengths'])}\n",
"transformed_dataset = dataset.map(transform_lengths)\n",
"print_dictionary_dataset(transformed_dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WD2lWw3fIXrg"
},
"source": [
"### tf.function\n",
"\n",
"[tf.function](https://www.tensorflow.org/guide/function) is a decorator that precomputes TensorFlow graphs for Python functions, which can substantially improve the performance of your TensorFlow code. Ragged tensors can be used transparently with `@tf.function`-decorated functions. For example, the following function works with both ragged and non-ragged tensors:"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:51.531518Z",
"iopub.status.busy": "2024-08-27T01:20:51.530909Z",
"iopub.status.idle": "2024-08-27T01:20:51.534724Z",
"shell.execute_reply": "2024-08-27T01:20:51.534095Z"
},
"id": "PfyxgVaj_8tl"
},
"outputs": [],
"source": [
"@tf.function\n",
"def make_palindrome(x, axis):\n",
" return tf.concat([x, tf.reverse(x, [axis])], axis)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:51.537808Z",
"iopub.status.busy": "2024-08-27T01:20:51.537453Z",
"iopub.status.idle": "2024-08-27T01:20:51.586397Z",
"shell.execute_reply": "2024-08-27T01:20:51.585755Z"
},
"id": "vcZdzvEnDEt0"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"make_palindrome(tf.constant([[1, 2], [3, 4], [5, 6]]), axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:51.590205Z",
"iopub.status.busy": "2024-08-27T01:20:51.589718Z",
"iopub.status.idle": "2024-08-27T01:20:51.683547Z",
"shell.execute_reply": "2024-08-27T01:20:51.682935Z"
},
"id": "4WfCMIgdDMxj"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-08-27 01:20:51.662289: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:933] Skipping loop optimization for Merge node with control input: RaggedConcat/assert_equal_1/Assert/AssertGuard/branch_executed/_9\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"make_palindrome(tf.ragged.constant([[1, 2], [3], [4, 5, 6]]), axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X2p69YPOBUz8"
},
"source": [
"If you wish to explicitly specify the `input_signature` for the `tf.function`, then you can do so using `tf.RaggedTensorSpec`."
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:51.687385Z",
"iopub.status.busy": "2024-08-27T01:20:51.686869Z",
"iopub.status.idle": "2024-08-27T01:20:51.840368Z",
"shell.execute_reply": "2024-08-27T01:20:51.839724Z"
},
"id": "k6-hkhdDBk6G"
},
"outputs": [
{
"data": {
"text/plain": [
"(,\n",
" )"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@tf.function(\n",
" input_signature=[tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int32)])\n",
"def max_and_min(rt):\n",
" return (tf.math.reduce_max(rt, axis=-1), tf.math.reduce_min(rt, axis=-1))\n",
"\n",
"max_and_min(tf.ragged.constant([[1, 2], [3], [4, 5, 6]]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fSs-7E0VD85q"
},
"source": [
"#### Concrete functions\n",
"\n",
"[Concrete functions](https://www.tensorflow.org/guide/function#obtaining_concrete_functions) encapsulate individual traced graphs that are built by `tf.function`. Ragged tensors can be used transparently with concrete functions.\n"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:51.844012Z",
"iopub.status.busy": "2024-08-27T01:20:51.843366Z",
"iopub.status.idle": "2024-08-27T01:20:51.883820Z",
"shell.execute_reply": "2024-08-27T01:20:51.883144Z"
},
"id": "yyJeXJ4wFWox"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"@tf.function\n",
"def increment(x):\n",
" return x + 1\n",
"\n",
"rt = tf.ragged.constant([[1, 2], [3], [4, 5, 6]])\n",
"cf = increment.get_concrete_function(rt)\n",
"print(cf(rt))\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iYLyPlatIXhh"
},
"source": [
"### SavedModels\n",
"\n",
"A [SavedModel](https://www.tensorflow.org/guide/saved_model) is a serialized TensorFlow program, including both weights and computation. It can be built from a Keras model or from a custom model. In either case, ragged tensors can be used transparently with the functions and methods defined by a SavedModel.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "98VpBSdOgWqL"
},
"source": [
"#### Example: saving a Keras model"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:51.887248Z",
"iopub.status.busy": "2024-08-27T01:20:51.886743Z",
"iopub.status.idle": "2024-08-27T01:20:54.263975Z",
"shell.execute_reply": "2024-08-27T01:20:54.263179Z"
},
"id": "D-Dg9w7Je5pU"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import tempfile\n",
"\n",
"keras_module_path = tempfile.mkdtemp()\n",
"keras_model.save(keras_module_path+\"/my_model.keras\")\n",
"\n",
"imported_model = tf.keras.models.load_model(keras_module_path+\"/my_model.keras\")\n",
"\n",
"imported_model(hashed_words.to_tensor())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9-7k-E92gaoR"
},
"source": [
"#### Example: saving a custom model\n"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.267570Z",
"iopub.status.busy": "2024-08-27T01:20:54.267259Z",
"iopub.status.idle": "2024-08-27T01:20:54.403650Z",
"shell.execute_reply": "2024-08-27T01:20:54.403004Z"
},
"id": "Sfem1ESrdGzX"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpkmr703wo/assets\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpkmr703wo/assets\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class CustomModule(tf.Module):\n",
" def __init__(self, variable_value):\n",
" super(CustomModule, self).__init__()\n",
" self.v = tf.Variable(variable_value)\n",
"\n",
" @tf.function\n",
" def grow(self, x):\n",
" return x * self.v\n",
"\n",
"module = CustomModule(100.0)\n",
"\n",
"# Before saving a custom model, you must ensure that concrete functions are\n",
"# built for each input signature that you will need.\n",
"module.grow.get_concrete_function(tf.RaggedTensorSpec(shape=[None, None],\n",
" dtype=tf.float32))\n",
"\n",
"custom_module_path = tempfile.mkdtemp()\n",
"tf.saved_model.save(module, custom_module_path)\n",
"imported_model = tf.saved_model.load(custom_module_path)\n",
"imported_model.grow(tf.ragged.constant([[1.0, 4.0, 3.0], [2.0]]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SAxis5KBhrBN"
},
"source": [
"Note: SavedModel [signatures](https://www.tensorflow.org/guide/saved_model#specifying_signatures_during_export) are concrete functions. As discussed in the section on Concrete Functions above, ragged tensors are only handled correctly by concrete functions starting with TensorFlow 2.3. If you need to use SavedModel signatures in a previous version of TensorFlow, then it's recommended that you decompose the ragged tensor into its component tensors."
]
},
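{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following cell is a minimal sketch of that decomposition (the function name `grow_components` is illustrative): the signature accepts and returns plain tensors, and the ragged tensor is rebuilt from its components outside the signature."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A sketch of passing a ragged tensor through a signature as its components.\n",
"@tf.function(input_signature=[\n",
"    tf.TensorSpec([None], tf.float32),  # flat values\n",
"    tf.TensorSpec([None], tf.int64)])   # row splits\n",
"def grow_components(values, row_splits):\n",
"  return {'values': values * 100.0, 'row_splits': row_splits}\n",
"\n",
"rt = tf.ragged.constant([[1.0, 4.0, 3.0], [2.0]])\n",
"result = grow_components(rt.values, rt.row_splits)\n",
"print(tf.RaggedTensor.from_row_splits(result['values'], result['row_splits']))"
]
},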
{
"cell_type": "markdown",
"metadata": {
"id": "cRcHzS6pcHYC"
},
"source": [
"## Overloaded operators\n",
"\n",
"The `RaggedTensor` class overloads the standard Python arithmetic and comparison operators, making it easy to perform basic elementwise math:"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.407287Z",
"iopub.status.busy": "2024-08-27T01:20:54.406798Z",
"iopub.status.idle": "2024-08-27T01:20:54.435410Z",
"shell.execute_reply": "2024-08-27T01:20:54.434755Z"
},
"id": "skScd37P-PVu"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"x = tf.ragged.constant([[1, 2], [3], [4, 5, 6]])\n",
"y = tf.ragged.constant([[1, 1], [2], [3, 3, 3]])\n",
"print(x + y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XEGgbZHV-PVw"
},
"source": [
"Since the overloaded operators perform elementwise computations, the inputs to all binary operations must have the same shape or be broadcastable to the same shape. In the simplest broadcasting case, a single scalar is combined elementwise with each value in a ragged tensor:"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.438638Z",
"iopub.status.busy": "2024-08-27T01:20:54.438414Z",
"iopub.status.idle": "2024-08-27T01:20:54.443187Z",
"shell.execute_reply": "2024-08-27T01:20:54.442592Z"
},
"id": "IYybEEWc-PVx"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"x = tf.ragged.constant([[1, 2], [3], [4, 5, 6]])\n",
"print(x + 3)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "okGb9dIi-PVz"
},
"source": [
"For a discussion of more advanced cases, check the section on **Broadcasting**.\n",
"\n",
"Ragged tensors overload the same set of operators as normal `Tensor`s: the unary operators `-`, `~`, and `abs()`; and the binary operators `+`, `-`, `*`, `/`, `//`, `%`, `**`, `&`, `|`, `^`, `==`, `<`, `<=`, `>`, and `>=`.\n"
]
},
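{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sketch, the comparison operators behave the same way, producing an elementwise boolean ragged tensor (reusing `x` and `y` from the cells above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Elementwise comparison of two ragged tensors with matching row lengths.\n",
"print(x == y)\n",
"print(x > y)"
]
},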
{
"cell_type": "markdown",
"metadata": {
"id": "f2anbs6ZnFtl"
},
"source": [
"## Indexing\n",
"\n",
"Ragged tensors support Python-style indexing, including multidimensional indexing and slicing. The following examples demonstrate ragged tensor indexing with a 2D and a 3D ragged tensor."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XuEwmC3t_ITL"
},
"source": [
"### Indexing examples: 2D ragged tensor"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.446470Z",
"iopub.status.busy": "2024-08-27T01:20:54.446209Z",
"iopub.status.idle": "2024-08-27T01:20:54.450312Z",
"shell.execute_reply": "2024-08-27T01:20:54.449732Z"
},
"id": "MbSRZRDz-PV1"
},
"outputs": [],
"source": [
"queries = tf.ragged.constant(\n",
" [['Who', 'is', 'George', 'Washington'],\n",
" ['What', 'is', 'the', 'weather', 'tomorrow'],\n",
" ['Goodnight']])"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.453592Z",
"iopub.status.busy": "2024-08-27T01:20:54.453080Z",
"iopub.status.idle": "2024-08-27T01:20:54.460826Z",
"shell.execute_reply": "2024-08-27T01:20:54.460267Z"
},
"id": "2HRs2xhh-vZE"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([b'What' b'is' b'the' b'weather' b'tomorrow'], shape=(5,), dtype=string)\n"
]
}
],
"source": [
"print(queries[1]) # A single query"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.463827Z",
"iopub.status.busy": "2024-08-27T01:20:54.463595Z",
"iopub.status.idle": "2024-08-27T01:20:54.470193Z",
"shell.execute_reply": "2024-08-27T01:20:54.469567Z"
},
"id": "EFfjZV7YA3UH"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(b'the', shape=(), dtype=string)\n"
]
}
],
"source": [
"print(queries[1, 2]) # A single word"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.473236Z",
"iopub.status.busy": "2024-08-27T01:20:54.473013Z",
"iopub.status.idle": "2024-08-27T01:20:54.480752Z",
"shell.execute_reply": "2024-08-27T01:20:54.480121Z"
},
"id": "VISRPQSdA3xn"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(queries[1:]) # Everything but the first row"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.483568Z",
"iopub.status.busy": "2024-08-27T01:20:54.483336Z",
"iopub.status.idle": "2024-08-27T01:20:54.489169Z",
"shell.execute_reply": "2024-08-27T01:20:54.488534Z"
},
"id": "J1PpSyKQBMng"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(queries[:, :3]) # The first 3 words of each query"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.492259Z",
"iopub.status.busy": "2024-08-27T01:20:54.491742Z",
"iopub.status.idle": "2024-08-27T01:20:54.497504Z",
"shell.execute_reply": "2024-08-27T01:20:54.496587Z"
},
"id": "ixrhHmJBeidy"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(queries[:, -2:]) # The last 2 words of each query"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cnOP6Vza-PV4"
},
"source": [
"### Indexing examples: 3D ragged tensor"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.500461Z",
"iopub.status.busy": "2024-08-27T01:20:54.500248Z",
"iopub.status.idle": "2024-08-27T01:20:54.504410Z",
"shell.execute_reply": "2024-08-27T01:20:54.503898Z"
},
"id": "8VbqbKcE-PV6"
},
"outputs": [],
"source": [
"rt = tf.ragged.constant([[[1, 2, 3], [4]],\n",
" [[5], [], [6]],\n",
" [[7]],\n",
" [[8, 9], [10]]])"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.507146Z",
"iopub.status.busy": "2024-08-27T01:20:54.506880Z",
"iopub.status.idle": "2024-08-27T01:20:54.517323Z",
"shell.execute_reply": "2024-08-27T01:20:54.516677Z"
},
"id": "f9WPVWf4grVp"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(rt[1]) # Second row (2D RaggedTensor)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.520107Z",
"iopub.status.busy": "2024-08-27T01:20:54.519883Z",
"iopub.status.idle": "2024-08-27T01:20:54.530882Z",
"shell.execute_reply": "2024-08-27T01:20:54.530305Z"
},
"id": "ad8FGJoABjQH"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([8 9], shape=(2,), dtype=int32)\n"
]
}
],
"source": [
"print(rt[3, 0]) # First element of fourth row (1D Tensor)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.533730Z",
"iopub.status.busy": "2024-08-27T01:20:54.533492Z",
"iopub.status.idle": "2024-08-27T01:20:54.540792Z",
"shell.execute_reply": "2024-08-27T01:20:54.540171Z"
},
"id": "MPPr-a-bBjFE"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(rt[:, 1:3]) # Items 1-3 of each row (3D RaggedTensor)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.544188Z",
"iopub.status.busy": "2024-08-27T01:20:54.543610Z",
"iopub.status.idle": "2024-08-27T01:20:54.549368Z",
"shell.execute_reply": "2024-08-27T01:20:54.548771Z"
},
"id": "6SIDeoIUBi4z"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(rt[:, -1:]) # Last item of each row (3D RaggedTensor)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_d3nBh1GnWvU"
},
"source": [
"`RaggedTensor`s support multidimensional indexing and slicing with one restriction: indexing into a ragged dimension is not allowed. This case is problematic because the indicated value may exist in some rows but not others. In such cases, it's not obvious whether you should (1) raise an `IndexError`; (2) use a default value; or (3) skip that value and return a tensor with fewer rows than you started with. Following the [guiding principles of Python](https://www.python.org/dev/peps/pep-0020/) (\"In the face of ambiguity, refuse the temptation to guess\"), this operation is currently disallowed."
]
},
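{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, here is a small sketch of that restriction using the `queries` tensor defined above; selecting a fixed position within the ragged (word) dimension is rejected:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Indexing into the ragged dimension is disallowed, because some rows may\n",
"# not have a word at that position.\n",
"try:\n",
"  queries[:, 3]\n",
"except (ValueError, tf.errors.InvalidArgumentError) as e:\n",
"  print(\"Got expected error:\", e)"
]
},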
{
"cell_type": "markdown",
"metadata": {
"id": "IsWKETULAJbN"
},
"source": [
"## Tensor type conversion\n",
"\n",
"The `RaggedTensor` class defines methods that can be used to convert\n",
"between `RaggedTensor`s and `tf.Tensor`s or `tf.SparseTensors`:"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.552476Z",
"iopub.status.busy": "2024-08-27T01:20:54.552240Z",
"iopub.status.idle": "2024-08-27T01:20:54.555992Z",
"shell.execute_reply": "2024-08-27T01:20:54.555394Z"
},
"id": "INnfmZGcBoU_"
},
"outputs": [],
"source": [
"ragged_sentences = tf.ragged.constant([\n",
" ['Hi'], ['Welcome', 'to', 'the', 'fair'], ['Have', 'fun']])"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.558928Z",
"iopub.status.busy": "2024-08-27T01:20:54.558681Z",
"iopub.status.idle": "2024-08-27T01:20:54.563651Z",
"shell.execute_reply": "2024-08-27T01:20:54.563048Z"
},
"id": "__iJ4iXtkGOx"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[[b'Hi' b'' b'' b'' b'' b'' b'' b'' b'' b'']\n",
" [b'Welcome' b'to' b'the' b'fair' b'' b'' b'' b'' b'' b'']\n",
" [b'Have' b'fun' b'' b'' b'' b'' b'' b'' b'' b'']], shape=(3, 10), dtype=string)\n"
]
}
],
"source": [
"# RaggedTensor -> Tensor\n",
"print(ragged_sentences.to_tensor(default_value='', shape=[None, 10]))"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.566370Z",
"iopub.status.busy": "2024-08-27T01:20:54.566131Z",
"iopub.status.idle": "2024-08-27T01:20:54.637678Z",
"shell.execute_reply": "2024-08-27T01:20:54.636937Z"
},
"id": "-rfiyYqne8QN"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Tensor -> RaggedTensor\n",
"x = [[1, 3, -1, -1], [2, -1, -1, -1], [4, 5, 8, 9]]\n",
"print(tf.RaggedTensor.from_tensor(x, padding=-1))"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.641025Z",
"iopub.status.busy": "2024-08-27T01:20:54.640781Z",
"iopub.status.idle": "2024-08-27T01:20:54.644936Z",
"shell.execute_reply": "2024-08-27T01:20:54.644295Z"
},
"id": "41WAZLXNnbwH"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SparseTensor(indices=tf.Tensor(\n",
"[[0 0]\n",
" [1 0]\n",
" [1 1]\n",
" [1 2]\n",
" [1 3]\n",
" [2 0]\n",
" [2 1]], shape=(7, 2), dtype=int64), values=tf.Tensor([b'Hi' b'Welcome' b'to' b'the' b'fair' b'Have' b'fun'], shape=(7,), dtype=string), dense_shape=tf.Tensor([3 4], shape=(2,), dtype=int64))\n"
]
}
],
"source": [
"#RaggedTensor -> SparseTensor\n",
"print(ragged_sentences.to_sparse())"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.648024Z",
"iopub.status.busy": "2024-08-27T01:20:54.647657Z",
"iopub.status.idle": "2024-08-27T01:20:54.665751Z",
"shell.execute_reply": "2024-08-27T01:20:54.665145Z"
},
"id": "S8MkYo2hfVhj"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# SparseTensor -> RaggedTensor\n",
"st = tf.SparseTensor(indices=[[0, 0], [2, 0], [2, 1]],\n",
" values=['a', 'b', 'c'],\n",
" dense_shape=[3, 3])\n",
"print(tf.RaggedTensor.from_sparse(st))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qx025sNMkAHH"
},
"source": [
"## Evaluating ragged tensors\n",
"\n",
"To access the values in a ragged tensor, you can:\n",
"\n",
"1. Use `tf.RaggedTensor.to_list` to convert the ragged tensor to a nested Python list.\n",
"2. Use `tf.RaggedTensor.numpy` to convert the ragged tensor to a NumPy array whose values are nested NumPy arrays.\n",
"3. Decompose the ragged tensor into its components, using the `tf.RaggedTensor.values` and `tf.RaggedTensor.row_splits` properties, or row-partitioning methods such as `tf.RaggedTensor.row_lengths` and `tf.RaggedTensor.value_rowids`.\n",
"4. Use Python indexing to select values from the ragged tensor.\n"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.669161Z",
"iopub.status.busy": "2024-08-27T01:20:54.668737Z",
"iopub.status.idle": "2024-08-27T01:20:54.677810Z",
"shell.execute_reply": "2024-08-27T01:20:54.677184Z"
},
"id": "uMm1WMkc-PV_"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Python list: [[1, 2], [3, 4, 5], [6], [], [7]]\n",
"NumPy array: [array([1, 2], dtype=int32) array([3, 4, 5], dtype=int32)\n",
" array([6], dtype=int32) array([], dtype=int32) array([7], dtype=int32)]\n",
"Values: [1 2 3 4 5 6 7]\n",
"Splits: [0 2 5 6 6 7]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Indexed value: [3 4 5]\n"
]
}
],
"source": [
"rt = tf.ragged.constant([[1, 2], [3, 4, 5], [6], [], [7]])\n",
"print(\"Python list:\", rt.to_list())\n",
"print(\"NumPy array:\", rt.numpy())\n",
"print(\"Values:\", rt.values.numpy())\n",
"print(\"Splits:\", rt.row_splits.numpy())\n",
"print(\"Indexed value:\", rt[1].numpy())"
]
},
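{
"cell_type": "markdown",
"metadata": {},
"source": [
"The row-partitioning methods mentioned in the list above can be evaluated the same way; a short sketch for the same tensor:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Alternative views of the same row partition.\n",
"print(\"Row lengths:\", rt.row_lengths().numpy())\n",
"print(\"Value rowids:\", rt.value_rowids().numpy())"
]
},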
{
"cell_type": "markdown",
"metadata": {
"id": "J87jMZa0M_YW"
},
"source": [
"## Ragged Shapes\n",
"\n",
"The shape of a tensor specifies the size of each axis. For example, the shape of `[[1, 2], [3, 4], [5, 6]]` is `[3, 2]`, since there are 3 rows and 2 columns. TensorFlow has two separate but related ways to describe shapes:\n",
"\n",
"* ***static shape***: Information about axis sizes that is known statically (e.g., while tracing a `tf.function`). May be partially specified.\n",
"\n",
"* ***dynamic shape***: Runtime information about the axis sizes."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IOETE_OLPLZo"
},
"source": [
"### Static shape\n",
"\n",
"A Tensor's static shape contains information about its axis sizes that is known at graph-construction time. For both `tf.Tensor` and `tf.RaggedTensor`, it is available using the `.shape` property, and is encoded using `tf.TensorShape`:"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.681103Z",
"iopub.status.busy": "2024-08-27T01:20:54.680724Z",
"iopub.status.idle": "2024-08-27T01:20:54.685197Z",
"shell.execute_reply": "2024-08-27T01:20:54.684593Z"
},
"id": "btGDjT4uNgQy"
},
"outputs": [
{
"data": {
"text/plain": [
"TensorShape([3, 2])"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = tf.constant([[1, 2], [3, 4], [5, 6]])\n",
"x.shape # shape of a tf.tensor"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.688094Z",
"iopub.status.busy": "2024-08-27T01:20:54.687856Z",
"iopub.status.idle": "2024-08-27T01:20:54.692869Z",
"shell.execute_reply": "2024-08-27T01:20:54.692266Z"
},
"id": "__OgvmrGPEjq"
},
"outputs": [
{
"data": {
"text/plain": [
"TensorShape([4, None])"
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rt = tf.ragged.constant([[1], [2, 3], [], [4]])\n",
"rt.shape # shape of a tf.RaggedTensor"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9EWnQd3qPWaw"
},
"source": [
"The static shape of a ragged dimension is always `None` (i.e., unspecified). However, the inverse is not true -- if a `TensorShape` dimension is `None`, then that could indicate that the dimension is ragged, *or* it could indicate that the dimension is uniform but that its size is not statically known."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "75E9YXYMNfne"
},
"source": [
"### Dynamic shape\n",
"\n",
"A tensor's dynamic shape contains information about its axis sizes that is known when the graph is run. It is constructed using the `tf.shape` operation. For `tf.Tensor`, `tf.shape` returns the shape as a 1D integer `Tensor`, where `tf.shape(x)[i]` is the size of axis `i`."
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.696090Z",
"iopub.status.busy": "2024-08-27T01:20:54.695859Z",
"iopub.status.idle": "2024-08-27T01:20:54.701475Z",
"shell.execute_reply": "2024-08-27T01:20:54.700822Z"
},
"id": "kWJ7Cn1EQTD_"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = tf.constant([['a', 'b'], ['c', 'd'], ['e', 'f']])\n",
"tf.shape(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BeZEfxwmRcSv"
},
"source": [
"However, a 1D `Tensor` is not expressive enough to describe the shape of a `tf.RaggedTensor`. Instead, the dynamic shape for ragged tensors is encoded using a dedicated type, `tf.experimental.DynamicRaggedShape`. In the following example, the `DynamicRaggedShape` returned by `tf.shape(rt)` indicates that the ragged tensor has 4 rows, with lengths 1, 3, 0, and 2:"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.704765Z",
"iopub.status.busy": "2024-08-27T01:20:54.704243Z",
"iopub.status.idle": "2024-08-27T01:20:54.710228Z",
"shell.execute_reply": "2024-08-27T01:20:54.709615Z"
},
"id": "nZc2wqgQQUFU"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"rt = tf.ragged.constant([[1], [2, 3, 4], [], [5, 6]])\n",
"rt_shape = tf.shape(rt)\n",
"print(rt_shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EphU60YvTf98"
},
"source": [
"#### Dynamic shape: operations\n",
"\n",
"`DynamicRaggedShape`s can be used with most TensorFlow ops that expect shapes, including `tf.reshape`, `tf.zeros`, `tf.ones`. `tf.fill`, `tf.broadcast_dynamic_shape`, and `tf.broadcast_to`."
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.713728Z",
"iopub.status.busy": "2024-08-27T01:20:54.713146Z",
"iopub.status.idle": "2024-08-27T01:20:54.720911Z",
"shell.execute_reply": "2024-08-27T01:20:54.720305Z"
},
"id": "pclAODLXT6Gr"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.reshape(x, rt_shape) = \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.zeros(rt_shape) = \n",
"tf.ones(rt_shape) = \n",
"tf.fill(rt_shape, 9) = \n"
]
}
],
"source": [
"print(f\"tf.reshape(x, rt_shape) = {tf.reshape(x, rt_shape)}\")\n",
"print(f\"tf.zeros(rt_shape) = {tf.zeros(rt_shape)}\")\n",
"print(f\"tf.ones(rt_shape) = {tf.ones(rt_shape)}\")\n",
"print(f\"tf.fill(rt_shape, 9) = {tf.fill(rt_shape, 'x')}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rNP_3_btRAHj"
},
"source": [
"#### Dynamic shape: indexing and slicing\n",
"\n",
"`DynamicRaggedShape` can be also be indexed to get the sizes of uniform dimensions. For example, we can find the number of rows in a raggedtensor using `tf.shape(rt)[0]` (just as we would for a non-ragged tensor):"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.724127Z",
"iopub.status.busy": "2024-08-27T01:20:54.723890Z",
"iopub.status.idle": "2024-08-27T01:20:54.728409Z",
"shell.execute_reply": "2024-08-27T01:20:54.727688Z"
},
"id": "MzQvPhsxS6HN"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rt_shape[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wvr2iT6zS_e8"
},
"source": [
"However, it is an error to use indexing to try to retrieve the size of a ragged dimension, since it doesn't have a single size. (Since `RaggedTensor` keeps track of which axes are ragged, this error is only thrown during eager execution or when tracing a `tf.function`; it will never be thrown when executing a concrete function.)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.731741Z",
"iopub.status.busy": "2024-08-27T01:20:54.731217Z",
"iopub.status.idle": "2024-08-27T01:20:54.735031Z",
"shell.execute_reply": "2024-08-27T01:20:54.734358Z"
},
"id": "HgGMk0LeTGik"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Got expected ValueError: Index 1 is not uniform\n"
]
}
],
"source": [
"try:\n",
" rt_shape[1]\n",
"except ValueError as e:\n",
" print(\"Got expected ValueError:\", e)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5QUsdawGU0SM"
},
"source": [
"`DynamicRaggedShape`s can also be sliced, as long as the slice either begins with axis `0`, or contains only dense dimensions."
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.738772Z",
"iopub.status.busy": "2024-08-27T01:20:54.738206Z",
"iopub.status.idle": "2024-08-27T01:20:54.743128Z",
"shell.execute_reply": "2024-08-27T01:20:54.742553Z"
},
"id": "APT72EaBU70t"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rt_shape[:1]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a-Wl9IrQXcdY"
},
"source": [
"#### Dynamic shape: encoding\n",
"\n",
"`DynamicRaggedShape` is encoded using two fields:\n",
"\n",
"* `inner_shape`: An integer vector giving the shape of a dense `tf.Tensor`.\n",
"* `row_partitions`: A list of `tf.experimental.RowPartition` objects, describing how the outermost dimension of that inner shape should be partitioned to add ragged axes.\n",
"\n",
"For more information about row partitions, see the \"RaggedTensor encoding\" section below, and the API docs for `tf.experimental.RowPartition`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jfeY9tTcV_zL"
},
"source": [
"#### Dynamic shape: construction\n",
"\n",
"`DynamicRaggedShape` is most often constructed by applying `tf.shape` to a `RaggedTensor`, but it can also be constructed directly:"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.746496Z",
"iopub.status.busy": "2024-08-27T01:20:54.746028Z",
"iopub.status.idle": "2024-08-27T01:20:54.753138Z",
"shell.execute_reply": "2024-08-27T01:20:54.752558Z"
},
"id": "NSRgD667WwIZ"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.experimental.DynamicRaggedShape(\n",
" row_partitions=[tf.experimental.RowPartition.from_row_lengths([5, 3, 2])],\n",
" inner_shape=[10, 8])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EjzVjs9MXIIA"
},
"source": [
"If the lengths of all rows are known statically, `DynamicRaggedShape.from_lengths` can also be used to construct a dynamic ragged shape. (This is mostly useful for testing and demonstration code, since it's rare for the lengths of ragged dimensions to be known statically).\n"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.756370Z",
"iopub.status.busy": "2024-08-27T01:20:54.755911Z",
"iopub.status.idle": "2024-08-27T01:20:54.762631Z",
"shell.execute_reply": "2024-08-27T01:20:54.762022Z"
},
"id": "gMxCzADUYIjY"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.experimental.DynamicRaggedShape.from_lengths([4, (2, 1, 0, 8), 12])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EdljbNPq-PWS"
},
"source": [
"### Broadcasting\n",
"\n",
"Broadcasting is the process of making tensors with different shapes have compatible shapes for elementwise operations. For more background on broadcasting, refer to:\n",
"\n",
"- [NumPy: Broadcasting](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)\n",
"- `tf.broadcast_dynamic_shape`\n",
"- `tf.broadcast_to`\n",
"\n",
"The basic steps for broadcasting two inputs `x` and `y` to have compatible shapes are:\n",
"\n",
"1. If `x` and `y` do not have the same number of dimensions, then add outer dimensions (with size 1) until they do.\n",
"\n",
"2. For each dimension where `x` and `y` have different sizes:\n",
"\n",
"- If `x` or `y` have size `1` in dimension `d`, then repeat its values across dimension `d` to match the other input's size.\n",
"- Otherwise, raise an exception (`x` and `y` are not broadcast compatible).\n",
"\n",
"Where the size of a tensor in a uniform dimension is a single number (the size of slices across that dimension); and the size of a tensor in a ragged dimension is a list of slice lengths (for all slices across that dimension)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-S2hOUWx-PWU"
},
"source": [
"#### Broadcasting examples"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.766113Z",
"iopub.status.busy": "2024-08-27T01:20:54.765558Z",
"iopub.status.idle": "2024-08-27T01:20:54.770580Z",
"shell.execute_reply": "2024-08-27T01:20:54.769866Z"
},
"id": "0n095XdR-PWU"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# x (2D ragged): 2 x (num_rows)\n",
"# y (scalar)\n",
"# result (2D ragged): 2 x (num_rows)\n",
"x = tf.ragged.constant([[1, 2], [3]])\n",
"y = 3\n",
"print(x + y)"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.774027Z",
"iopub.status.busy": "2024-08-27T01:20:54.773442Z",
"iopub.status.idle": "2024-08-27T01:20:54.804767Z",
"shell.execute_reply": "2024-08-27T01:20:54.804143Z"
},
"id": "0SVYk5AP-PWW"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# x (2d ragged): 3 x (num_rows)\n",
"# y (2d tensor): 3 x 1\n",
"# Result (2d ragged): 3 x (num_rows)\n",
"x = tf.ragged.constant(\n",
" [[10, 87, 12],\n",
" [19, 53],\n",
" [12, 32]])\n",
"y = [[1000], [2000], [3000]]\n",
"print(x + y)"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.808047Z",
"iopub.status.busy": "2024-08-27T01:20:54.807466Z",
"iopub.status.idle": "2024-08-27T01:20:54.860314Z",
"shell.execute_reply": "2024-08-27T01:20:54.859672Z"
},
"id": "MsfBMD80s8Ux"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# x (3d ragged): 2 x (r1) x 2\n",
"# y (2d ragged): 1 x 1\n",
"# Result (3d ragged): 2 x (r1) x 2\n",
"x = tf.ragged.constant(\n",
" [[[1, 2], [3, 4], [5, 6]],\n",
" [[7, 8]]],\n",
" ragged_rank=1)\n",
"y = tf.constant([[10]])\n",
"print(x + y)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.863714Z",
"iopub.status.busy": "2024-08-27T01:20:54.863147Z",
"iopub.status.idle": "2024-08-27T01:20:54.869603Z",
"shell.execute_reply": "2024-08-27T01:20:54.868988Z"
},
"id": "rEj5QVfnva0t"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# x (3d ragged): 2 x (r1) x (r2) x 1\n",
"# y (1d tensor): 3\n",
"# Result (3d ragged): 2 x (r1) x (r2) x 3\n",
"x = tf.ragged.constant(\n",
" [\n",
" [\n",
" [[1], [2]],\n",
" [],\n",
" [[3]],\n",
" [[4]],\n",
" ],\n",
" [\n",
" [[5], [6]],\n",
" [[7]]\n",
" ]\n",
" ],\n",
" ragged_rank=2)\n",
"y = tf.constant([10, 20, 30])\n",
"print(x + y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uennZ64Aqftb"
},
"source": [
"Here are some examples of shapes that do not broadcast:"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.873000Z",
"iopub.status.busy": "2024-08-27T01:20:54.872414Z",
"iopub.status.idle": "2024-08-27T01:20:54.906714Z",
"shell.execute_reply": "2024-08-27T01:20:54.906054Z"
},
"id": "UpI0FlfL4Eim"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Condition x == y did not hold.\n",
"Indices of first 3 different values:\n",
"[[1]\n",
" [2]\n",
" [3]]\n",
"Corresponding x values:\n",
"[ 4 8 12]\n",
"Corresponding y values:\n",
"[2 6 7]\n",
"First 3 elements of x:\n",
"[0 4 8]\n",
"First 3 elements of y:\n",
"[0 2 6]\n"
]
}
],
"source": [
"# x (2d ragged): 3 x (r1)\n",
"# y (2d tensor): 3 x 4 # trailing dimensions do not match\n",
"x = tf.ragged.constant([[1, 2], [3, 4, 5, 6], [7]])\n",
"y = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])\n",
"try:\n",
" x + y\n",
"except tf.errors.InvalidArgumentError as exception:\n",
" print(exception)"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.909603Z",
"iopub.status.busy": "2024-08-27T01:20:54.909390Z",
"iopub.status.idle": "2024-08-27T01:20:54.941291Z",
"shell.execute_reply": "2024-08-27T01:20:54.940534Z"
},
"id": "qGq1zOT4zMoc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Condition x == y did not hold.\n",
"Indices of first 2 different values:\n",
"[[1]\n",
" [3]]\n",
"Corresponding x values:\n",
"[3 6]\n",
"Corresponding y values:\n",
"[2 5]\n",
"First 3 elements of x:\n",
"[0 3 4]\n",
"First 3 elements of y:\n",
"[0 2 4]\n"
]
}
],
"source": [
"# x (2d ragged): 3 x (r1)\n",
"# y (2d ragged): 3 x (r2) # ragged dimensions do not match.\n",
"x = tf.ragged.constant([[1, 2, 3], [4], [5, 6]])\n",
"y = tf.ragged.constant([[10, 20], [30, 40], [50]])\n",
"try:\n",
" x + y\n",
"except tf.errors.InvalidArgumentError as exception:\n",
" print(exception)"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:54.944485Z",
"iopub.status.busy": "2024-08-27T01:20:54.944024Z",
"iopub.status.idle": "2024-08-27T01:20:54.997055Z",
"shell.execute_reply": "2024-08-27T01:20:54.996380Z"
},
"id": "CvLae5vMqeji"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Condition x == y did not hold.\n",
"Indices of first 3 different values:\n",
"[[1]\n",
" [2]\n",
" [3]]\n",
"Corresponding x values:\n",
"[2 4 6]\n",
"Corresponding y values:\n",
"[3 6 9]\n",
"First 3 elements of x:\n",
"[0 2 4]\n",
"First 3 elements of y:\n",
"[0 3 6]\n"
]
}
],
"source": [
"# x (3d ragged): 3 x (r1) x 2\n",
"# y (3d ragged): 3 x (r1) x 3 # trailing dimensions do not match\n",
"x = tf.ragged.constant([[[1, 2], [3, 4], [5, 6]],\n",
" [[7, 8], [9, 10]]])\n",
"y = tf.ragged.constant([[[1, 2, 0], [3, 4, 0], [5, 6, 0]],\n",
" [[7, 8, 0], [9, 10, 0]]])\n",
"try:\n",
" x + y\n",
"except tf.errors.InvalidArgumentError as exception:\n",
" print(exception)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m0wQkLfV-PWa"
},
"source": [
"## RaggedTensor encoding\n",
"\n",
"Ragged tensors are encoded using the `RaggedTensor` class. Internally, each `RaggedTensor` consists of:\n",
"\n",
"- A `values` tensor, which concatenates the variable-length rows into a flattened list.\n",
"- A `row_partition`, which indicates how those flattened values are divided into rows.\n",
"\n",
"\n",
"\n",
"The `row_partition` can be stored using four different encodings:\n",
"\n",
"- `row_splits` is an integer vector specifying the split points between rows.\n",
"- `value_rowids` is an integer vector specifying the row index for each value.\n",
"- `row_lengths` is an integer vector specifying the length of each row.\n",
"- `uniform_row_length` is an integer scalar specifying a single length for all rows.\n",
"\n",
"\n",
"\n",
"An integer scalar `nrows` can also be included in the `row_partition` encoding to account for empty trailing rows with `value_rowids` or empty rows with `uniform_row_length`.\n"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:55.000256Z",
"iopub.status.busy": "2024-08-27T01:20:55.000001Z",
"iopub.status.idle": "2024-08-27T01:20:55.008808Z",
"shell.execute_reply": "2024-08-27T01:20:55.008232Z"
},
"id": "MrLgMu0gPuo-"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"rt = tf.RaggedTensor.from_row_splits(\n",
" values=[3, 1, 4, 1, 5, 9, 2],\n",
" row_splits=[0, 4, 4, 6, 7])\n",
"print(rt)"
]
},
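{
"cell_type": "markdown",
"metadata": {},
"source": [
"For illustration, here is a sketch of the same tensor built with the factory methods for the other row-partition encodings listed above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The same ragged tensor, constructed from row lengths and from value rowids.\n",
"values = [3, 1, 4, 1, 5, 9, 2]\n",
"print(tf.RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 2, 1]))\n",
"print(tf.RaggedTensor.from_value_rowids(\n",
"    values, value_rowids=[0, 0, 0, 0, 2, 2, 3], nrows=4))"
]
},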
{
"cell_type": "markdown",
"metadata": {
"id": "wEfZOKwN1Ra_"
},
"source": [
"The choice of which encoding to use for row partitions is managed internally by ragged tensors to improve efficiency in some contexts. In particular, some of the advantages and disadvantages of the different row-partitioning schemes are:\n",
"\n",
"- **Efficient indexing**: The `row_splits` encoding enables constant-time indexing and slicing into ragged tensors.\n",
"- **Efficient concatenation**: The `row_lengths` encoding is more efficient when concatenating ragged tensors, since row lengths do not change when two tensors are concatenated together.\n",
"- **Small encoding size**: The `value_rowids` encoding is more efficient when storing ragged tensors that have a large number of empty rows, since the size of the tensor depends only on the total number of values. On the other hand, the `row_splits` and `row_lengths` encodings are more efficient when storing ragged tensors with longer rows, since they require only one scalar value for each row.\n",
"- **Compatibility**: The `value_rowids` scheme matches the [segmentation](https://www.tensorflow.org/api_docs/python/tf/math#about_segmentation) format used by operations, such as `tf.segment_sum`. The `row_limits` scheme matches the format used by ops such as `tf.sequence_mask`.\n",
"- **Uniform dimensions**: As discussed below, the `uniform_row_length` encoding is used to encode ragged tensors with uniform dimensions."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bpB7xKoUPtU6"
},
"source": [
"### Multiple ragged dimensions\n",
"\n",
"A ragged tensor with multiple ragged dimensions is encoded by using a nested `RaggedTensor` for the `values` tensor. Each nested `RaggedTensor` adds a single ragged dimension.\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:55.012019Z",
"iopub.status.busy": "2024-08-27T01:20:55.011654Z",
"iopub.status.idle": "2024-08-27T01:20:55.024270Z",
"shell.execute_reply": "2024-08-27T01:20:55.023681Z"
},
"id": "yy3IGT2a-PWb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Shape: (3, None, None)\n",
"Number of partitioned dimensions: 2\n"
]
}
],
"source": [
"rt = tf.RaggedTensor.from_row_splits(\n",
" values=tf.RaggedTensor.from_row_splits(\n",
" values=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],\n",
" row_splits=[0, 3, 3, 5, 9, 10]),\n",
" row_splits=[0, 1, 1, 5])\n",
"print(rt)\n",
"print(\"Shape: {}\".format(rt.shape))\n",
"print(\"Number of partitioned dimensions: {}\".format(rt.ragged_rank))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5HqEEDzk-PWc"
},
"source": [
"The factory function `tf.RaggedTensor.from_nested_row_splits` may be used to construct a RaggedTensor with multiple ragged dimensions directly by providing a list of `row_splits` tensors:"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:55.027591Z",
"iopub.status.busy": "2024-08-27T01:20:55.026997Z",
"iopub.status.idle": "2024-08-27T01:20:55.039131Z",
"shell.execute_reply": "2024-08-27T01:20:55.038556Z"
},
"id": "AKYhtFcT-PWd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"rt = tf.RaggedTensor.from_nested_row_splits(\n",
" flat_values=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],\n",
" nested_row_splits=([0, 1, 1, 5], [0, 3, 3, 5, 9, 10]))\n",
"print(rt)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BqAfbkAC56m0"
},
"source": [
"### Ragged rank and flat values\n",
"\n",
"A ragged tensor's ***ragged rank*** is the number of times that the underlying `values` tensor has been partitioned (i.e. the nesting depth of `RaggedTensor` objects). The innermost `values` tensor is known as its ***flat_values***. In the following example, `conversations` has ragged_rank=3, and its `flat_values` is a 1D `Tensor` with 24 strings:\n"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:55.042186Z",
"iopub.status.busy": "2024-08-27T01:20:55.041804Z",
"iopub.status.idle": "2024-08-27T01:20:55.048541Z",
"shell.execute_reply": "2024-08-27T01:20:55.047965Z"
},
"id": "BXp-Tt2bClem"
},
"outputs": [
{
"data": {
"text/plain": [
"TensorShape([2, None, None, None])"
]
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# shape = [batch, (paragraph), (sentence), (word)]\n",
"conversations = tf.ragged.constant(\n",
" [[[[\"I\", \"like\", \"ragged\", \"tensors.\"]],\n",
" [[\"Oh\", \"yeah?\"], [\"What\", \"can\", \"you\", \"use\", \"them\", \"for?\"]],\n",
" [[\"Processing\", \"variable\", \"length\", \"data!\"]]],\n",
" [[[\"I\", \"like\", \"cheese.\"], [\"Do\", \"you?\"]],\n",
" [[\"Yes.\"], [\"I\", \"do.\"]]]])\n",
"conversations.shape"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:55.051644Z",
"iopub.status.busy": "2024-08-27T01:20:55.051219Z",
"iopub.status.idle": "2024-08-27T01:20:55.055476Z",
"shell.execute_reply": "2024-08-27T01:20:55.054846Z"
},
"id": "DZUMrgxXFd5s"
},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 87,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"assert conversations.ragged_rank == len(conversations.nested_row_splits)\n",
"conversations.ragged_rank # Number of partitioned dimensions."
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:55.058564Z",
"iopub.status.busy": "2024-08-27T01:20:55.058050Z",
"iopub.status.idle": "2024-08-27T01:20:55.062248Z",
"shell.execute_reply": "2024-08-27T01:20:55.061704Z"
},
"id": "xXLSNpS0Fdvp"
},
"outputs": [
{
"data": {
"text/plain": [
"array([b'I', b'like', b'ragged', b'tensors.', b'Oh', b'yeah?', b'What',\n",
" b'can', b'you', b'use', b'them', b'for?', b'Processing',\n",
" b'variable', b'length', b'data!', b'I', b'like', b'cheese.', b'Do',\n",
" b'you?', b'Yes.', b'I', b'do.'], dtype=object)"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"conversations.flat_values.numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uba2EnAY-PWf"
},
"source": [
"### Uniform inner dimensions\n",
"\n",
"Ragged tensors with uniform inner dimensions are encoded by using a\n",
"multidimensional `tf.Tensor` for the flat_values (i.e., the innermost `values`).\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:55.065508Z",
"iopub.status.busy": "2024-08-27T01:20:55.064982Z",
"iopub.status.idle": "2024-08-27T01:20:55.074907Z",
"shell.execute_reply": "2024-08-27T01:20:55.074317Z"
},
"id": "z2sHwHdy-PWg"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Shape: (3, None, 2)\n",
"Number of partitioned dimensions: 1\n",
"Flat values shape: (6, 2)\n",
"Flat values:\n",
"[[1 3]\n",
" [0 0]\n",
" [1 3]\n",
" [5 3]\n",
" [3 3]\n",
" [1 2]]\n"
]
}
],
"source": [
"rt = tf.RaggedTensor.from_row_splits(\n",
" values=[[1, 3], [0, 0], [1, 3], [5, 3], [3, 3], [1, 2]],\n",
" row_splits=[0, 3, 4, 6])\n",
"print(rt)\n",
"print(\"Shape: {}\".format(rt.shape))\n",
"print(\"Number of partitioned dimensions: {}\".format(rt.ragged_rank))\n",
"print(\"Flat values shape: {}\".format(rt.flat_values.shape))\n",
"print(\"Flat values:\\n{}\".format(rt.flat_values))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WoGRKd50x_qz"
},
"source": [
"### Uniform non-inner dimensions\n",
"\n",
"Ragged tensors with uniform non-inner dimensions are encoded by partitioning rows with `uniform_row_length`.\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-27T01:20:55.078318Z",
"iopub.status.busy": "2024-08-27T01:20:55.077753Z",
"iopub.status.idle": "2024-08-27T01:20:55.090097Z",
"shell.execute_reply": "2024-08-27T01:20:55.089514Z"
},
"id": "70q1aCKwySgS"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Shape: (2, 2, None)\n",
"Number of partitioned dimensions: 2\n"
]
}
],
"source": [
"rt = tf.RaggedTensor.from_uniform_row_length(\n",
" values=tf.RaggedTensor.from_row_splits(\n",
" values=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],\n",
" row_splits=[0, 3, 5, 9, 10]),\n",
" uniform_row_length=2)\n",
"print(rt)\n",
"print(\"Shape: {}\".format(rt.shape))\n",
"print(\"Number of partitioned dimensions: {}\".format(rt.ragged_rank))"
]
}
],
"metadata": {
"colab": {
"name": "ragged_tensor.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 0
}