{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "ISubpr_SSsiM"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "form",
"execution": {
"iopub.execute_input": "2024-08-15T02:57:28.696660Z",
"iopub.status.busy": "2024-08-15T02:57:28.696430Z",
"iopub.status.idle": "2024-08-15T02:57:28.700350Z",
"shell.execute_reply": "2024-08-15T02:57:28.699790Z"
},
"id": "3jTMb1dySr3V"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6DWfyNThSziV"
},
"source": [
"# Better performance with tf.function\n",
"\n",
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "J122XQYG7W6w"
},
"source": [
"In TensorFlow 2, [eager execution](basics.ipynb) is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.\n",
"\n",
"You can use `tf.function` to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This will help you create performant and portable models, and it is required to use `SavedModel`.\n",
"\n",
"This guide will help you conceptualize how `tf.function` works under the hood, so you can use it effectively.\n",
"\n",
"The main takeaways and recommendations are:\n",
"\n",
"- Debug in eager mode, then decorate with `@tf.function`.\n",
"- Don't rely on Python side effects like object mutation or list appends.\n",
"- `tf.function` works best with TensorFlow ops; NumPy and Python calls are converted to constants.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SjvqpgepHJPd"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:28.703975Z",
"iopub.status.busy": "2024-08-15T02:57:28.703417Z",
"iopub.status.idle": "2024-08-15T02:57:31.078374Z",
"shell.execute_reply": "2024-08-15T02:57:31.077653Z"
},
"id": "otIdN1TS8N7S"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-08-15 02:57:28.958444: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-08-15 02:57:28.979712: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-08-15 02:57:28.986177: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
]
}
],
"source": [
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "I0xDjO4SHLUD"
},
"source": [
"Define a helper function to demonstrate the kinds of errors you might encounter:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:31.082294Z",
"iopub.status.busy": "2024-08-15T02:57:31.081923Z",
"iopub.status.idle": "2024-08-15T02:57:31.086568Z",
"shell.execute_reply": "2024-08-15T02:57:31.085929Z"
},
"id": "D25apou9IOXa"
},
"outputs": [],
"source": [
"import traceback\n",
"import contextlib\n",
"\n",
"# Some helper code to demonstrate the kinds of errors you might encounter.\n",
"@contextlib.contextmanager\n",
"def assert_raises(error_class):\n",
" try:\n",
" yield\n",
" except error_class as e:\n",
" print('Caught expected exception \\n {}:'.format(error_class))\n",
" traceback.print_exc(limit=2)\n",
" except Exception as e:\n",
" raise e\n",
" else:\n",
" raise Exception('Expected {} to be raised but no error was raised!'.format(\n",
" error_class))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WPSfepzTHThq"
},
"source": [
"## Basics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CNwYTIJ8r56W"
},
"source": [
"### Usage\n",
"\n",
"A `tf.function` that you define (for example by applying the `@tf.function` decorator) is just like a core TensorFlow operation: You can execute it eagerly; you can compute gradients; and so on."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:31.089649Z",
"iopub.status.busy": "2024-08-15T02:57:31.089412Z",
"iopub.status.idle": "2024-08-15T02:57:33.330353Z",
"shell.execute_reply": "2024-08-15T02:57:33.329288Z"
},
"id": "SbtT1-Wm70F2"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"I0000 00:00:1723690651.607368 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690651.611235 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690651.614398 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690651.618234 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690651.629890 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690651.633433 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690651.636337 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690651.639748 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690651.643233 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690651.646588 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690651.649526 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690651.652949 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.865955 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.868101 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.870112 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.872121 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.874165 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.876153 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.878068 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.879960 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.881883 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.883841 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.885768 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.887660 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.926250 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.928321 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.930298 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.932288 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.934241 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.936253 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.938172 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.940080 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.942041 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.944593 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.946947 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n",
"I0000 00:00:1723690652.949245 167534 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n"
]
},
{
"data": {
"text/plain": [
"
"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@tf.function # The decorator converts `add` into a `PolymorphicFunction`.\n",
"def add(a, b):\n",
" return a + b\n",
"\n",
"add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:33.333783Z",
"iopub.status.busy": "2024-08-15T02:57:33.333536Z",
"iopub.status.idle": "2024-08-15T02:57:33.374824Z",
"shell.execute_reply": "2024-08-15T02:57:33.374164Z"
},
"id": "uP-zUelB8DbX"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v = tf.Variable(1.0)\n",
"with tf.GradientTape() as tape:\n",
" result = add(v, 1.0)\n",
"tape.gradient(result, v)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ocWZvqrmHnmX"
},
"source": [
"You can use `tf.function`s inside other `tf.function`s."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:33.378190Z",
"iopub.status.busy": "2024-08-15T02:57:33.377954Z",
"iopub.status.idle": "2024-08-15T02:57:33.469949Z",
"shell.execute_reply": "2024-08-15T02:57:33.469352Z"
},
"id": "l5qRjdbBVdU6"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@tf.function\n",
"def dense_layer(x, w, b):\n",
" return add(tf.matmul(x, w), b)\n",
"\n",
"dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "piBhz7gYsHqU"
},
"source": [
"`tf.function`s can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:33.473063Z",
"iopub.status.busy": "2024-08-15T02:57:33.472829Z",
"iopub.status.idle": "2024-08-15T02:57:34.579714Z",
"shell.execute_reply": "2024-08-15T02:57:34.578630Z"
},
"id": "zuXt4wRysI03"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"W0000 00:00:1723690654.228267 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.285525 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.290477 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.295072 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.299820 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.304580 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.322737 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.327483 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.332646 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.337747 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.343046 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.347480 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.361780 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.370325 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.381185 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n",
"W0000 00:00:1723690654.405763 167534 gpu_timer.cc:114] Skipping the delay kernel, measurement accuracy will be reduced\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Eager conv: 0.011224052000216034\n",
"Function conv: 0.005400947000453016\n",
"Note how there's not much difference in performance for convolutions\n"
]
}
],
"source": [
"import timeit\n",
"conv_layer = tf.keras.layers.Conv2D(100, 3)\n",
"\n",
"@tf.function\n",
"def conv_fn(image):\n",
" return conv_layer(image)\n",
"\n",
"image = tf.zeros([1, 200, 200, 100])\n",
"# Warm up\n",
"conv_layer(image); conv_fn(image)\n",
"print(\"Eager conv:\", timeit.timeit(lambda: conv_layer(image), number=10))\n",
"print(\"Function conv:\", timeit.timeit(lambda: conv_fn(image), number=10))\n",
"print(\"Note how there's not much difference in performance for convolutions\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uZ4Do2AV80cO"
},
"source": [
"### Tracing\n",
"\n",
"This section exposes how `tf.function` works under the hood, including implementation details *which may change in the future*. However, once you understand why and when tracing happens, it's much easier to use `tf.function` effectively!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nhpUtRqsXoyM"
},
"source": [
"#### What is \"tracing\"?\n",
"\n",
"A `tf.function` runs your program in a [TensorFlow Graph](https://www.tensorflow.org/guide/intro_to_graphs#what_are_graphs). However, a `tf.Graph` cannot represent all the things that you'd write in an eager TensorFlow program. For instance, Python supports polymorphism, but `tf.Graph` requires its inputs to have a specified data type and dimension. Or you may perform side tasks like reading command-line arguments, raising an error, or working with a more complex Python object; none of these things can run in a `tf.Graph`.\n",
"\n",
"`tf.function` bridges this gap by separating your code in two stages:\n",
"\n",
" 1) In the first stage, referred to as \"**tracing**\", `tf.function` creates a new `tf.Graph`. Python code runs normally, but all TensorFlow operations (like adding two Tensors) are *deferred*: they are captured by the `tf.Graph` and not run.\n",
"\n",
" 2) In the second stage, a `tf.Graph` which contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage.\n",
"\n",
"Depending on its inputs, `tf.function` will not always run the first stage when it is called. See [\"Rules of tracing\"](#rules_of_tracing) below to get a better sense of how it makes that determination. Skipping the first stage and only executing the second stage is what gives you TensorFlow's high performance.\n",
"\n",
"When `tf.function` does decide to trace, the tracing stage is immediately followed by the second stage, so calling the `tf.function` both creates and runs the `tf.Graph`. Later you will see how you can run only the tracing stage with [`get_concrete_function`](#obtaining_concrete_functions)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K7scSzLx662f"
},
"source": [
"When you pass arguments of different types into a `tf.function`, both stages are run:\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:34.583531Z",
"iopub.status.busy": "2024-08-15T02:57:34.583283Z",
"iopub.status.idle": "2024-08-15T02:57:34.657236Z",
"shell.execute_reply": "2024-08-15T02:57:34.656550Z"
},
"id": "kojmJrgq8U9v"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing with Tensor(\"a:0\", shape=(), dtype=int32)\n",
"tf.Tensor(2, shape=(), dtype=int32)\n",
"\n",
"Tracing with Tensor(\"a:0\", shape=(), dtype=float32)\n",
"tf.Tensor(2.2, shape=(), dtype=float32)\n",
"\n",
"Tracing with Tensor(\"a:0\", shape=(), dtype=string)\n",
"tf.Tensor(b'aa', shape=(), dtype=string)\n",
"\n"
]
}
],
"source": [
"@tf.function\n",
"def double(a):\n",
" print(\"Tracing with\", a)\n",
" return a + a\n",
"\n",
"print(double(tf.constant(1)))\n",
"print()\n",
"print(double(tf.constant(1.1)))\n",
"print()\n",
"print(double(tf.constant(\"a\")))\n",
"print()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QPfouGUQrcNb"
},
"source": [
"Note that if you repeatedly call a `tf.function` with the same argument type, TensorFlow will skip the tracing stage and reuse a previously traced graph, as the generated graph would be identical."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:34.660485Z",
"iopub.status.busy": "2024-08-15T02:57:34.660223Z",
"iopub.status.idle": "2024-08-15T02:57:34.664446Z",
"shell.execute_reply": "2024-08-15T02:57:34.663819Z"
},
"id": "hFccbWFRrsBp"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(b'bb', shape=(), dtype=string)\n"
]
}
],
"source": [
"# This doesn't print 'Tracing with ...'\n",
"print(double(tf.constant(\"b\")))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fgIO_XEzcB9o"
},
"source": [
"You can use `pretty_printed_concrete_signatures()` to see all of the available traces:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:34.667763Z",
"iopub.status.busy": "2024-08-15T02:57:34.667381Z",
"iopub.status.idle": "2024-08-15T02:57:34.671065Z",
"shell.execute_reply": "2024-08-15T02:57:34.670440Z"
},
"id": "IiQc4IKAb-NX"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input Parameters:\n",
" a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.int32, name=None)\n",
"Output Type:\n",
" TensorSpec(shape=(), dtype=tf.int32, name=None)\n",
"Captures:\n",
" None\n",
"\n",
"Input Parameters:\n",
" a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)\n",
"Output Type:\n",
" TensorSpec(shape=(), dtype=tf.float32, name=None)\n",
"Captures:\n",
" None\n",
"\n",
"Input Parameters:\n",
" a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None)\n",
"Output Type:\n",
" TensorSpec(shape=(), dtype=tf.string, name=None)\n",
"Captures:\n",
" None\n"
]
}
],
"source": [
"print(double.pretty_printed_concrete_signatures())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rKQ92VEWI7n8"
},
"source": [
"So far, you've seen that `tf.function` creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic. To be more specific about the terminology:\n",
"\n",
"- A `tf.Graph` is the raw, language-agnostic, portable representation of a TensorFlow computation.\n",
"- Tracing is the process through which new `tf.Graph`s are generated from Python code.\n",
"- An instance of `tf.Graph` is specialized to the specific input types it was traced with. Differing types require retracing.\n",
"- Each traced `tf.Graph` has a corresponding `ConcreteFunction`.\n",
"- A `tf.function` manages a cache of `ConcreteFunction`s and picks the right one for your inputs.\n",
"- `tf.function` wraps the Python function that will be traced, returning a `tf.types.experimental.PolymorphicFunction` object.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "129-iRsPS-gY"
},
"source": [
"#### Rules of tracing\n",
"\n",
"When called, a `tf.function` first evaluates the type of each input argument using the `tf.types.experimental.TraceType` of each argument. This is used to construct a `tf.types.experimental.FunctionType` describing the signature of the desired `ConcreteFunction`. We compare this `FunctionType` to the `FunctionType`s of existing `ConcreteFunction`s. If a matching `ConcreteFunction` is found, the call is dispatched to it. If no match is found, a new `ConcreteFunction` is traced for the desired `FunctionType`.\n",
"\n",
"If multiple matches are found, the most specific signature is chosen. Matching is done by [subtyping](https://en.wikipedia.org/wiki/Subtyping), much like normal function calls in C++ or Java, for instance. For example, `TensorShape([1, 2])` is a subtype of `TensorShape([None, None])` and so a call to the tf.function with `TensorShape([1, 2])` can be dispatched to the `ConcreteFunction` produced with `TensorShape([None, None])` but if a `ConcreteFunction` with `TensorShape([1, None])` also exists then it will be prioritized since it is more specific.\n",
"\n",
"The `TraceType` is determined from input arguments as follows:\n",
"* For `Tensor`, the type is parameterized by the `Tensor`'s `dtype` and `shape`; ranked shapes are a subtype of unranked shapes; fixed dimensions are a subtype of unknown dimensions\n",
"* For `Variable`, the type is similar to `Tensor`, but also includes a unique resource ID of the variable, necessary to correctly wire control dependencies\n",
"* For Python primitive values, the type corresponds to the **value** itself. For example, the `TraceType` of the value `3` is `LiteralTraceType<3>`, not `int`.\n",
"* For Python ordered containers such as `list` and `tuple`, etc., the type is parameterized by the types of their elements; for example, the type of `[1, 2]` is `ListTraceType, LiteralTraceType<2>>` and the type for `[2, 1]` is `ListTraceType, LiteralTraceType<1>>` which is different.\n",
"* For Python mappings such as `dict`, the type is also a mapping from the same keys but to the types of values instead of the actual values. For example, the type of `{1: 2, 3: 4}`, is `MappingTraceType<>>, >>>`. However, unlike ordered containers, `{1: 2, 3: 4}` and `{3: 4, 1: 2}` have equivalent types.\n",
"* For Python objects which implement the `__tf_tracing_type__` method, the type is whatever that method returns.\n",
"* For any other Python objects, the type is a generic `TraceType`, and the matching precedure is:\n",
" * First it checks if the object is the same object used in the previous trace (using Python `id()` or `is`). Note that this will still match if the object has changed, so if you use Python objects as `tf.function` arguments it's best to use *immutable* ones.\n",
" * Next it checks if the object is equal to the object used in the previous trace (using Python `==`).\n",
" \n",
" Note that this procedure only keeps a [weakref](https://docs.python.org/3/library/weakref.html) to the object and hence only works as long as the object is in scope/not deleted.\n"
]
},
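{
"cell_type": "markdown",
"metadata": {},
"source": [
"The rules for Python primitive values and for `Tensor` arguments can be checked with a small experiment. The following is a minimal sketch rather than part of the original guide: `scale` and `trace_log` are hypothetical names introduced for illustration. The list append is a Python side effect, so it only runs while tracing, and the length of the list therefore counts traces. If a value is expected to vary from call to call, passing it as a tensor avoids the extra traces."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"trace_log = []  # Appended to only while tracing (Python side effect).\n",
"\n",
"@tf.function\n",
"def scale(x, factor):\n",
"  trace_log.append(factor)\n",
"  return x * factor\n",
"\n",
"x = tf.constant([1.0, 2.0])\n",
"\n",
"# Python literals are typed by value: 2 and 3 each trigger a trace.\n",
"scale(x, 2)\n",
"scale(x, 3)\n",
"print(len(trace_log))  # 2\n",
"\n",
"# Tensors are typed by dtype and shape: the second call reuses the trace.\n",
"scale(x, tf.constant(2.0))\n",
"scale(x, tf.constant(3.0))\n",
"print(len(trace_log))  # 3"
]
},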
{
"cell_type": "markdown",
"metadata": {
"id": "GNNN4lgRzpIs"
},
"source": [
"Note: `TraceType` is based on the `tf.function` input parameters so changes to global and [free variables](https://docs.python.org/3/reference/executionmodel.html#binding-of-names) alone will not create a new trace. See [this section](#depending_on_python_global_and_free_variables) for recommended practices when dealing with Python global and free variables."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PEDwbumO32Wh"
},
"source": [
"### Controlling retracing\n",
"\n",
"Retracing, which is when your `tf.function` creates more than one trace, helps ensure that TensorFlow generates correct graphs for each set of inputs. However, tracing is an expensive operation! If your `tf.function` retraces a new graph for every call, you'll find that your code executes more slowly than if you didn't use `tf.function`.\n",
"\n",
"To control the tracing behavior, you can use the following techniques:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EUtycWJa34TT"
},
"source": [
"#### Pass a fixed `input_signature` to `tf.function`\n",
"\n",
"This forces `tf.function` to constrain itself to only one `tf.types.experimental.FunctionType` composed of the types enumerated by the `input_signature`. Calls that cannot be dispatched to this `FunctionType` will throw an error."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:34.674545Z",
"iopub.status.busy": "2024-08-15T02:57:34.674312Z",
"iopub.status.idle": "2024-08-15T02:57:35.382971Z",
"shell.execute_reply": "2024-08-15T02:57:35.382235Z"
},
"id": "_BDMIRmu1RGB"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing with Tensor(\"x:0\", shape=(None,), dtype=int32)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([4 1], shape=(2,), dtype=int32)\n",
"Caught expected exception \n",
" :\n",
"Caught expected exception \n",
" :\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/tmpfs/tmp/ipykernel_167534/3551158538.py\", line 8, in assert_raises\n",
" yield\n",
" File \"/tmpfs/tmp/ipykernel_167534/3657259638.py\", line 9, in \n",
" next_collatz(tf.constant([[1, 2], [3, 4]]))\n",
"TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2, 2), dtype=tf.int32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).\n",
"Traceback (most recent call last):\n",
" File \"/tmpfs/tmp/ipykernel_167534/3551158538.py\", line 8, in assert_raises\n",
" yield\n",
" File \"/tmpfs/tmp/ipykernel_167534/3657259638.py\", line 13, in \n",
" next_collatz(tf.constant([1.0, 2.0]))\n",
"TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(2,), dtype=tf.float32, name=None) to TensorSpec(shape=(None,), dtype=tf.int32, name=None)`. Received args: (,) and kwargs: {} for signature: (x: TensorSpec(shape=(None,), dtype=tf.int32, name=None)).\n"
]
}
],
"source": [
"@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))\n",
"def next_collatz(x):\n",
" print(\"Tracing with\", x)\n",
" return tf.where(x % 2 == 0, x // 2, 3 * x + 1)\n",
"\n",
"print(next_collatz(tf.constant([1, 2])))\n",
"# You specified a 1-D tensor in the input signature, so this should fail.\n",
"with assert_raises(TypeError):\n",
" next_collatz(tf.constant([[1, 2], [3, 4]]))\n",
"\n",
"# You specified an int32 dtype in the input signature, so this should fail.\n",
"with assert_raises(TypeError):\n",
" next_collatz(tf.constant([1.0, 2.0]))\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ocxX-HVk7P2o"
},
"source": [
"#### Use unknown dimensions for flexibility\n",
"\n",
" Since TensorFlow matches tensors based on their shape, using a `None` dimension as a wildcard will allow `tf.function`s to reuse traces for variably-sized input. Variably-sized input can occur if you have sequences of different length, or images of different sizes for each batch. You can check out the [Transformer](https://www.tensorflow.org/text/tutorials/transformer) and [Deep Dream](../tutorials/generative/deepdream.ipynb) tutorials for examples."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.387540Z",
"iopub.status.busy": "2024-08-15T02:57:35.386926Z",
"iopub.status.idle": "2024-08-15T02:57:35.426841Z",
"shell.execute_reply": "2024-08-15T02:57:35.426096Z"
},
"id": "4Viun7dh7PmF"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing with Tensor(\"x:0\", shape=(None,), dtype=int32)\n",
"tf.Tensor([1 2 3], shape=(3,), dtype=int32)\n",
"tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)\n"
]
}
],
"source": [
"@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))\n",
"def g(x):\n",
" print('Tracing with', x)\n",
" return x\n",
"\n",
"# No retrace!\n",
"print(g(tf.constant([1, 2, 3])))\n",
"print(g(tf.constant([1, 2, 3, 4, 5])))\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "37cc12f93cbd"
},
"source": [
"#### Use `reduce_retracing` for automatic flexibility\n",
"\n",
"When `reduce_retracing` is enabled, `tf.function` automatically identifies supertypes of the input types it is observing and chooses to trace more generalized graphs automatically. It is less efficient than setting the `input_signature` directly but useful when many types need to be supported."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.430174Z",
"iopub.status.busy": "2024-08-15T02:57:35.429607Z",
"iopub.status.idle": "2024-08-15T02:57:35.476705Z",
"shell.execute_reply": "2024-08-15T02:57:35.475994Z"
},
"id": "0403fae03a1f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing with Tensor(\"x:0\", shape=(3,), dtype=int32)\n",
"tf.Tensor([1 2 3], shape=(3,), dtype=int32)\n",
"Tracing with Tensor(\"x:0\", shape=(None,), dtype=int32)\n",
"tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)\n",
"tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int32)\n",
"tf.Tensor([1 2 3 4 5 6 7 8 9], shape=(9,), dtype=int32)\n"
]
}
],
"source": [
"@tf.function(reduce_retracing=True)\n",
"def g(x):\n",
" print('Tracing with', x)\n",
" return x\n",
"\n",
"# Traces once.\n",
"print(g(tf.constant([1, 2, 3])))\n",
"\n",
"# Traces again, but more generalized this time.\n",
"print(g(tf.constant([1, 2, 3, 4, 5])))\n",
"\n",
"# No more tracing!\n",
"print(g(tf.constant([1, 2, 3, 4, 5, 6, 7])))\n",
"print(g(tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9])))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AY5oiQN0XIyA"
},
"source": [
"#### Pass tensors instead of python literals\n",
"\n",
" Often, Python arguments are used to control hyperparameters and graph constructions - for example, `num_layers=10` or `training=True` or `nonlinearity='relu'`. So, if the Python argument changes, it makes sense that you'd have to retrace the graph.\n",
"\n",
" However, it's possible that a Python argument is not being used to control graph construction. In these cases, a change in the Python value can trigger needless retracing. Take, for example, this training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the generated graph is actually identical, so retracing is unnecessary."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.479971Z",
"iopub.status.busy": "2024-08-15T02:57:35.479460Z",
"iopub.status.idle": "2024-08-15T02:57:35.621722Z",
"shell.execute_reply": "2024-08-15T02:57:35.621115Z"
},
"id": "uydzR5JYUU8H"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Retracing occurs for different Python arguments.\n",
"Tracing with num_steps = 10\n",
"Executing with num_steps = 10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing with num_steps = 20\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Executing with num_steps = 20\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Traces are reused for Tensor arguments.\n",
"Tracing with num_steps = Tensor(\"num_steps:0\", shape=(), dtype=int32)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Executing with num_steps = 10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Executing with num_steps = 20\n"
]
}
],
"source": [
"def train_one_step():\n",
" pass\n",
"\n",
"@tf.function\n",
"def train(num_steps):\n",
" print(\"Tracing with num_steps = \", num_steps)\n",
" tf.print(\"Executing with num_steps = \", num_steps)\n",
" for _ in tf.range(num_steps):\n",
" train_one_step()\n",
"\n",
"print(\"Retracing occurs for different Python arguments.\")\n",
"train(num_steps=10)\n",
"train(num_steps=20)\n",
"\n",
"print()\n",
"print(\"Traces are reused for Tensor arguments.\")\n",
"train(num_steps=tf.constant(10))\n",
"train(num_steps=tf.constant(20))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4pJqkDR_Q2wz"
},
"source": [
"If you need to force retracing, create a new `tf.function`. Separate `tf.function` objects are guaranteed not to share traces."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.625387Z",
"iopub.status.busy": "2024-08-15T02:57:35.624749Z",
"iopub.status.idle": "2024-08-15T02:57:35.667176Z",
"shell.execute_reply": "2024-08-15T02:57:35.666531Z"
},
"id": "uHp4ousu4DdN"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing!\n",
"Executing\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing!\n",
"Executing\n"
]
}
],
"source": [
"def f():\n",
" print('Tracing!')\n",
" tf.print('Executing')\n",
"\n",
"tf.function(f)()\n",
"tf.function(f)()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-tZoWrA6INvc"
},
"source": [
"#### Use the tracing protocol\n",
"\n",
"Where possible, you should prefer converting the Python type into a `tf.experimental.ExtensionType` instead. Moreover, the `TraceType` of an `ExtensionType` is the `tf.TypeSpec` associated with it. Therefore, if needed, you can simply override the default `tf.TypeSpec` to take control of an `ExtensionType`'s `Tracing Protocol`. Refer to the _Customizing the ExtensionType's TypeSpec_ section in the [Extension types](extension_type.ipynb) guide for details.\n",
"\n",
"Otherwise, for direct control over when `tf.function` should retrace in regards to a particular Python type, you can implement the `Tracing Protocol` for it yourself."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.670445Z",
"iopub.status.busy": "2024-08-15T02:57:35.670205Z",
"iopub.status.idle": "2024-08-15T02:57:35.745155Z",
"shell.execute_reply": "2024-08-15T02:57:35.744507Z"
},
"id": "gZkIh7UaIKc6"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@tf.function\n",
"def get_mixed_flavor(fruit_a, fruit_b):\n",
" return fruit_a.flavor + fruit_b.flavor\n",
"\n",
"class Fruit:\n",
" flavor = tf.constant([0, 0])\n",
"\n",
"class Apple(Fruit):\n",
" flavor = tf.constant([1, 2])\n",
"\n",
"class Mango(Fruit):\n",
" flavor = tf.constant([3, 4])\n",
"\n",
"# As described in the above rules, a generic TraceType for `Apple` and `Mango`\n",
"# is generated (and a corresponding ConcreteFunction is traced) but it fails to\n",
"# match the second function call since the first pair of Apple() and Mango()\n",
"# have gone out out of scope by then and deleted.\n",
"get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function\n",
"get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again\n",
"\n",
"# However, each subclass of the `Fruit` class has a fixed flavor, and you\n",
"# can reuse an existing traced concrete function if it was the same\n",
"# subclass. Avoiding such unnecessary tracing of concrete functions\n",
"# can have significant performance benefits.\n",
"\n",
"class FruitTraceType(tf.types.experimental.TraceType):\n",
" def __init__(self, fruit):\n",
" self.fruit_type = type(fruit)\n",
" self.fruit_value = fruit\n",
"\n",
" def is_subtype_of(self, other):\n",
" # True if self subtypes `other` and `other`'s type matches FruitTraceType.\n",
" return (type(other) is FruitTraceType and\n",
" self.fruit_type is other.fruit_type)\n",
"\n",
" def most_specific_common_supertype(self, others):\n",
" # `self` is the specific common supertype if all input types match it.\n",
" return self if all(self == other for other in others) else None\n",
"\n",
" def placeholder_value(self, placeholder_context=None):\n",
" # Use the fruit itself instead of the type for correct tracing.\n",
" return self.fruit_value\n",
"\n",
" def __eq__(self, other):\n",
" return type(other) is FruitTraceType and self.fruit_type == other.fruit_type\n",
"\n",
" def __hash__(self):\n",
" return hash(self.fruit_type)\n",
"\n",
"class FruitWithTraceType:\n",
"\n",
" def __tf_tracing_type__(self, context):\n",
" return FruitTraceType(self)\n",
"\n",
"class AppleWithTraceType(FruitWithTraceType):\n",
" flavor = tf.constant([1, 2])\n",
"\n",
"class MangoWithTraceType(FruitWithTraceType):\n",
" flavor = tf.constant([3, 4])\n",
"\n",
"# Now if you try calling it again:\n",
"get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Traces a new concrete function\n",
"get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Re-uses the traced concrete function"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "96IxS2WR37fF"
},
"source": [
"### Obtaining concrete functions\n",
"\n",
"Every time a function is traced, a new concrete function is created. You can directly obtain a concrete function, by using `get_concrete_function`.\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.748873Z",
"iopub.status.busy": "2024-08-15T02:57:35.748623Z",
"iopub.status.idle": "2024-08-15T02:57:35.754374Z",
"shell.execute_reply": "2024-08-15T02:57:35.753759Z"
},
"id": "mHg2CGtPQ3Hz"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Obtaining concrete trace\n",
"Executing traced function\n",
"tf.Tensor(b'aa', shape=(), dtype=string)\n",
"tf.Tensor(b'bb', shape=(), dtype=string)\n"
]
}
],
"source": [
"print(\"Obtaining concrete trace\")\n",
"double_strings = double.get_concrete_function(tf.constant(\"a\"))\n",
"print(\"Executing traced function\")\n",
"print(double_strings(tf.constant(\"a\")))\n",
"print(double_strings(a=tf.constant(\"b\")))\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.757421Z",
"iopub.status.busy": "2024-08-15T02:57:35.757188Z",
"iopub.status.idle": "2024-08-15T02:57:35.761822Z",
"shell.execute_reply": "2024-08-15T02:57:35.761178Z"
},
"id": "6IVZ-NVf9vsx"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(b'cc', shape=(), dtype=string)\n"
]
}
],
"source": [
"# You can also call get_concrete_function on an InputSpec\n",
"double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))\n",
"print(double_strings_from_inputspec(tf.constant(\"c\")))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iR4fVmG34xvF"
},
"source": [
"Printing a `ConcreteFunction` displays a summary of its input arguments (with types) and its output type."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.765170Z",
"iopub.status.busy": "2024-08-15T02:57:35.764561Z",
"iopub.status.idle": "2024-08-15T02:57:35.768171Z",
"shell.execute_reply": "2024-08-15T02:57:35.767556Z"
},
"id": "o3-JbkIk41r8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ConcreteFunction Input Parameters:\n",
" a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None)\n",
"Output Type:\n",
" TensorSpec(shape=(), dtype=tf.string, name=None)\n",
"Captures:\n",
" None\n"
]
}
],
"source": [
"print(double_strings)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QtqfvljZeuOV"
},
"source": [
"You can also directly retrieve a concrete function's signature."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.771468Z",
"iopub.status.busy": "2024-08-15T02:57:35.770938Z",
"iopub.status.idle": "2024-08-15T02:57:35.774462Z",
"shell.execute_reply": "2024-08-15T02:57:35.773863Z"
},
"id": "nzbrqFABe0zG"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(a: TensorSpec(shape=(), dtype=tf.string, name=None)) -> TensorSpec(shape=(), dtype=tf.string, name=None)\n"
]
}
],
"source": [
"print(double_strings.function_type)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lar5A_5m5IG1"
},
"source": [
"Using a concrete trace with incompatible types will throw an error"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.777721Z",
"iopub.status.busy": "2024-08-15T02:57:35.777192Z",
"iopub.status.idle": "2024-08-15T02:57:35.782447Z",
"shell.execute_reply": "2024-08-15T02:57:35.781837Z"
},
"id": "G5eeTK-T5KYj"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Caught expected exception \n",
" :\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py\", line 442, in bind_function_inputs\n",
" bound_arguments = function_type.bind_with_defaults(\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/core/function/polymorphism/function_type.py\", line 277, in bind_with_defaults\n",
" with_default_args[arg_name] = constraint.cast(\n",
"TypeError: Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)\n",
"\n",
"The above exception was the direct cause of the following exception:\n",
"\n",
"Traceback (most recent call last):\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1179, in _call_impl\n",
" return self._call_with_structured_signature(args, kwargs)\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1259, in _call_with_structured_signature\n",
" function_type_utils.canonicalize_function_inputs(\n",
"TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec(shape=(), dtype=tf.int32, name=None) to TensorSpec(shape=(), dtype=tf.string, name=None)`. Received args: (,) and kwargs: {} for signature: (a: TensorSpec(shape=(), dtype=tf.string, name=None)) -> TensorSpec(shape=(), dtype=tf.string, name=None).\n",
"\n",
"During handling of the above exception, another exception occurred:\n",
"\n",
"Traceback (most recent call last):\n",
" File \"/tmpfs/tmp/ipykernel_167534/3551158538.py\", line 8, in assert_raises\n",
" yield\n",
" File \"/tmpfs/tmp/ipykernel_167534/3196284684.py\", line 2, in \n",
" double_strings(tf.constant(1))\n",
"tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_189 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_189]\n"
]
}
],
"source": [
"with assert_raises(tf.errors.InvalidArgumentError):\n",
" double_strings(tf.constant(1))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "st2L9VNQVtSG"
},
"source": [
"You may notice that Python arguments are given special treatment in a concrete function's input signature. Prior to TensorFlow 2.3, Python arguments were simply removed from the concrete function's signature. Starting with TensorFlow 2.3, Python arguments remain in the signature, but are constrained to take the value set during tracing."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.785445Z",
"iopub.status.busy": "2024-08-15T02:57:35.785218Z",
"iopub.status.idle": "2024-08-15T02:57:35.811770Z",
"shell.execute_reply": "2024-08-15T02:57:35.810954Z"
},
"id": "U_QyPSGoaC35"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ConcreteFunction Input Parameters:\n",
" a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=, dtype=tf.float32, name=None)\n",
" b (POSITIONAL_OR_KEYWORD): Literal[2]\n",
"Output Type:\n",
" TensorSpec(shape=, dtype=tf.float32, name=None)\n",
"Captures:\n",
" None\n"
]
}
],
"source": [
"@tf.function\n",
"def pow(a, b):\n",
" return a ** b\n",
"\n",
"square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)\n",
"print(square)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.814871Z",
"iopub.status.busy": "2024-08-15T02:57:35.814391Z",
"iopub.status.idle": "2024-08-15T02:57:35.897682Z",
"shell.execute_reply": "2024-08-15T02:57:35.897031Z"
},
"id": "E76vIDhQbXIb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Caught expected exception \n",
" :\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/function_type_utils.py\", line 442, in bind_function_inputs\n",
" bound_arguments = function_type.bind_with_defaults(\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/core/function/polymorphism/function_type.py\", line 277, in bind_with_defaults\n",
" with_default_args[arg_name] = constraint.cast(\n",
"ValueError: Can not cast 3 to Literal[2]\n",
"\n",
"The above exception was the direct cause of the following exception:\n",
"\n",
"Traceback (most recent call last):\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1179, in _call_impl\n",
" return self._call_with_structured_signature(args, kwargs)\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1259, in _call_with_structured_signature\n",
" function_type_utils.canonicalize_function_inputs(\n",
"TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=, dtype=tf.float32, name=None).\n",
"\n",
"During handling of the above exception, another exception occurred:\n",
"\n",
"Traceback (most recent call last):\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1182, in _call_impl\n",
" return self._call_with_flat_signature(args, kwargs)\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py\", line 1233, in _call_with_flat_signature\n",
" raise TypeError(f\"{self._flat_signature_summary()} got unexpected \"\n",
"TypeError: pow(a) got unexpected keyword arguments: b.\n",
"\n",
"During handling of the above exception, another exception occurred:\n",
"\n",
"Traceback (most recent call last):\n",
" File \"/tmpfs/tmp/ipykernel_167534/3551158538.py\", line 8, in assert_raises\n",
" yield\n",
" File \"/tmpfs/tmp/ipykernel_167534/2310937119.py\", line 4, in \n",
" square(tf.constant(10.0), b=3)\n",
"TypeError: Binding inputs to tf.function failed due to `Can not cast 3 to Literal[2]`. Received args: (,) and kwargs: {'b': 3} for signature: (a: TensorSpec(shape=, dtype=tf.float32, name=None), b: Literal[2]) -> TensorSpec(shape=, dtype=tf.float32, name=None).\n",
"Fallback to flat signature also failed due to: pow(a) got unexpected keyword arguments: b.\n"
]
}
],
"source": [
"assert square(tf.constant(10.0)) == 100\n",
"\n",
"with assert_raises(TypeError):\n",
" square(tf.constant(10.0), b=3)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "41gJh_JGIfuA"
},
"source": [
"### Obtaining graphs\n",
"\n",
"Although retrieving the actual `tf.Graph` object is not something you'll normally need to do, you can obtain it easily from any concrete function."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.901860Z",
"iopub.status.busy": "2024-08-15T02:57:35.901436Z",
"iopub.status.idle": "2024-08-15T02:57:35.905754Z",
"shell.execute_reply": "2024-08-15T02:57:35.905104Z"
},
"id": "5UENeGHfaX8g"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[] -> a\n",
"['a', 'a'] -> add\n",
"['add'] -> Identity\n"
]
}
],
"source": [
"graph = double_strings.graph\n",
"for node in graph.as_graph_def().node:\n",
" print(f'{node.input} -> {node.name}')\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2d49c486ccd4"
},
"source": [
"In reality, `tf.Graph`s are not directly callable. We actually use an `tf.types.experimental.AtomicFunction` to perform the computations described by the `tf.Graph`. You can access the `AtomicFunction` describing the traced `tf.Graph` and call it directly instead of the `ConcreteFunction`:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.909226Z",
"iopub.status.busy": "2024-08-15T02:57:35.908799Z",
"iopub.status.idle": "2024-08-15T02:57:35.913615Z",
"shell.execute_reply": "2024-08-15T02:57:35.913066Z"
},
"id": "4c3879aa0be0"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"atomic_fn = double_strings.inference_fn\n",
"atomic_fn(tf.constant(\"a\"))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c3bd1036c18c"
},
"source": [
"This has the advantage of having lower Python overhead for high-performance scenarios. But it should only be used for forward inference (no gradient support), and captured tensor values (if any) would need to be explicitly supplied."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aIKkgr6qdtp4"
},
"source": [
"### Debugging\n",
"\n",
"In general, debugging code is easier in eager mode than inside `tf.function`. You should ensure that your code executes error-free in eager mode before decorating with `tf.function`. To assist in the debugging process, you can call `tf.config.run_functions_eagerly(True)` to globally disable and reenable `tf.function`.\n",
"\n",
"When tracking down issues that only appear within `tf.function`, here are some tips:\n",
"- Plain old Python `print` calls only execute during tracing, helping you track down when your function gets (re)traced.\n",
"- `tf.print` calls will execute every time, and can help you track down intermediate values during execution.\n",
"- `tf.debugging.enable_check_numerics` is an easy way to track down where NaNs and Inf are created.\n",
"- `pdb` (the [Python debugger](https://docs.python.org/3/library/pdb.html)) can help you understand what's going on during tracing. (Caveat: `pdb` will drop you into AutoGraph-transformed source code.)"
]
},
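{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the eager-debugging toggle described above. The function `scaled_sum` is a hypothetical example, not part of the original guide. While `tf.config.run_functions_eagerly(True)` is in effect, the decorated function runs as ordinary Python, so the plain `print` should fire on every call and see concrete tensor values. The `finally` block restores graph execution so that later cells behave as usual."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"@tf.function\n",
"def scaled_sum(x):\n",
"  # With eager execution forced on, this Python print runs on every call,\n",
"  # not only during tracing.\n",
"  print(\"Python print sees:\", x)\n",
"  return tf.reduce_sum(x) * 2\n",
"\n",
"tf.config.run_functions_eagerly(True)\n",
"try:\n",
"  scaled_sum(tf.constant([1.0, 2.0]))\n",
"  scaled_sum(tf.constant([3.0, 4.0]))\n",
"finally:\n",
"  # Restore graph execution so that later cells are unaffected.\n",
"  tf.config.run_functions_eagerly(False)"
]
},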
{
"cell_type": "markdown",
"metadata": {
"id": "5f05Vr_YBUCz"
},
"source": [
"## AutoGraph transformations\n",
"\n",
"AutoGraph is a library that is on by default in `tf.function`, and transforms a subset of Python eager code into graph-compatible TensorFlow ops. This includes control flow like `if`, `for`, `while`.\n",
"\n",
"TensorFlow ops like `tf.cond` and `tf.while_loop` continue to work, but control flow is often easier to write and understand when written in Python."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:35.917002Z",
"iopub.status.busy": "2024-08-15T02:57:35.916766Z",
"iopub.status.idle": "2024-08-15T02:57:36.036176Z",
"shell.execute_reply": "2024-08-15T02:57:36.035499Z"
},
"id": "yCQTtTPTW3WF"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.722626925 0.640327692 0.725044 0.904435039 0.868018746]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.61853379 0.565122604 0.620023966 0.718450606 0.700366139]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.550106347 0.511768281 0.551144719 0.615948677 0.604600191]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.500599921 0.471321791 0.501377642 0.548301 0.540314913]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.462588847 0.439266682 0.463199914 0.499245733 0.493226349]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.432191819 0.413036436 0.432688653 0.461523771 0.456773371]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.407151431 0.391047835 0.407565802 0.431325316 0.427450746]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.386051297 0.372263193 0.386403859 0.406428277 0.403188676]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.367951065 0.355969697 0.368255854 0.38543576 0.382673979]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.352198243 0.341659099 0.352465183 0.367418766 0.365027398]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.338323593 0.328957736 0.33856 0.351731867 0.349634588]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.325979948 0.317583948 0.326191217 0.337910533 0.336051434]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.314903945 0.307320684 0.315094262 0.325610697 0.323947728]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.304891765 0.297997624 0.30506441 0.314571291 0.313072115]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.295782804 0.289479077 0.29594034 0.304590017 0.303229302]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.287448555 0.281655282 0.287593067 0.295507431 0.294265062]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.279784769 0.274436355 0.279917955 0.287195921 0.286055595]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.272705853 0.267748028 0.272829145 0.279551893 0.278500348]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.266140789 0.261528105 0.266255379 0.272490293 0.271516532]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.26003018 0.255724251 0.260137022 0.265940517 0.265035421]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.254323781 0.250291914 0.254423678 0.259843439 0.258999288]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.248978764 0.245193034 0.249072418 0.25414905 0.253359258]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.243958414 0.240394741 0.244046524 0.248814836 0.248073786]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.239231125 0.235868543 0.239314198 0.243804231 0.24310714]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.234769359 0.231589615 0.234847859 0.239085764 0.238428399]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.230549142 0.227536201 0.230623439 0.234632015 0.234010741]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.226549357 0.223689109 0.22661984 0.23041907 0.229830697]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.222751439 0.220031396 0.222818434 0.226425976 0.225867674]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.21913895 0.216548 0.219202697 0.222634196 0.222103462]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.215697214 0.213225439 0.215757981 0.219027311 0.218521982]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.212413162 0.210051686 0.212471202 0.215590775 0.215108871]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.209275112 0.207015961 0.209330618 0.212311521 0.211851314]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.206272557 0.204108506 0.206325665 0.209177911 0.20873782]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.203395993 0.201320544 0.203446865 0.206179485 0.20575805]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.200636819 0.198644072 0.200685605 0.203306749 0.202902704]\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# A simple loop\n",
"\n",
"@tf.function\n",
"def f(x):\n",
"  while tf.reduce_sum(x) > 1:\n",
"    tf.print(x)\n",
"    x = tf.tanh(x)\n",
"  return x\n",
"\n",
"f(tf.random.uniform([5]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KxwJ8znPI0Cg"
},
"source": [
"If you're curious, you can inspect the code that AutoGraph generates."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.039171Z",
"iopub.status.busy": "2024-08-15T02:57:36.038913Z",
"iopub.status.idle": "2024-08-15T02:57:36.043686Z",
"shell.execute_reply": "2024-08-15T02:57:36.043110Z"
},
"id": "jlQD1ffRXJhl"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def tf__f(x):\n",
"    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:\n",
"        do_return = False\n",
"        retval_ = ag__.UndefinedReturnValue()\n",
"\n",
"        def get_state():\n",
"            return (x,)\n",
"\n",
"        def set_state(vars_):\n",
"            nonlocal x\n",
"            (x,) = vars_\n",
"\n",
"        def loop_body():\n",
"            nonlocal x\n",
"            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)\n",
"            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)\n",
"\n",
"        def loop_test():\n",
"            return ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1\n",
"        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})\n",
"        try:\n",
"            do_return = True\n",
"            retval_ = ag__.ld(x)\n",
"        except:\n",
"            do_return = False\n",
"            raise\n",
"        return fscope.ret(retval_, do_return)\n",
"\n"
]
}
],
"source": [
"print(tf.autograph.to_code(f.python_function))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xgKmkrNTZSyz"
},
"source": [
"### Conditionals\n",
"\n",
"AutoGraph will convert some `if <condition>` statements into the equivalent `tf.cond` calls. This substitution is made if `<condition>` is a Tensor. Otherwise, the `if` statement is executed as a Python conditional.\n",
"\n",
"A Python conditional executes during tracing, so exactly one branch of the conditional will be added to the graph. Without AutoGraph, this traced graph would be unable to take the alternate branch if there is data-dependent control flow.\n",
"\n",
"`tf.cond` traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; check out [AutoGraph tracing effects](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#effects-of-the-tracing-process) for more information."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.046725Z",
"iopub.status.busy": "2024-08-15T02:57:36.046453Z",
"iopub.status.idle": "2024-08-15T02:57:36.248778Z",
"shell.execute_reply": "2024-08-15T02:57:36.248143Z"
},
"id": "BOQl8PMq2Sf3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing for loop\n",
"Tracing fizzbuzz branch\n",
"Tracing fizz branch\n",
"Tracing buzz branch\n",
"Tracing default branch\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fizz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"buzz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fizz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"buzz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fizz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"7\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"8\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fizz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"buzz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"11\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fizz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"13\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"14\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fizzbuzz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"16\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"17\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"fizz\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"19\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"buzz\n"
]
}
],
"source": [
"@tf.function\n",
"def fizzbuzz(n):\n",
"  for i in tf.range(1, n + 1):\n",
"    print('Tracing for loop')\n",
"    if i % 15 == 0:\n",
"      print('Tracing fizzbuzz branch')\n",
"      tf.print('fizzbuzz')\n",
"    elif i % 3 == 0:\n",
"      print('Tracing fizz branch')\n",
"      tf.print('fizz')\n",
"    elif i % 5 == 0:\n",
"      print('Tracing buzz branch')\n",
"      tf.print('buzz')\n",
"    else:\n",
"      print('Tracing default branch')\n",
"      tf.print(i)\n",
"\n",
"fizzbuzz(tf.constant(5))\n",
"fizzbuzz(tf.constant(20))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4rBO5AQ15HVC"
},
"source": [
"See the [reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#if-statements) for additional restrictions on AutoGraph-converted if statements."
]
},
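{
"cell_type": "markdown",
"metadata": {},
"source": [
"The difference between the two kinds of conditional can be sketched by listing the op types in each traced graph. This is a minimal illustration under stated assumptions, not a definitive recipe: `op_types`, `python_cond`, and `tensor_cond` are hypothetical names introduced here, and the exact op type emitted for `tf.cond` may differ between TensorFlow versions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"def op_types(fn, *args):\n",
"  # Hypothetical helper: distinct op types in the graph traced for these arguments.\n",
"  graph = fn.get_concrete_function(*args).graph\n",
"  return sorted({op.type for op in graph.get_operations()})\n",
"\n",
"@tf.function\n",
"def python_cond(x, use_square):\n",
"  if use_square:  # Python bool: resolved once, during tracing.\n",
"    return x * x\n",
"  return x + 1\n",
"\n",
"@tf.function\n",
"def tensor_cond(x):\n",
"  if x > 0:  # Tensor: AutoGraph converts this to `tf.cond`.\n",
"    return x * x\n",
"  return x + 1\n",
"\n",
"# Only the branch selected by the Python value ends up in the graph.\n",
"print(op_types(python_cond, tf.constant(2), True))\n",
"# The top-level graph holds a conditional op that selects a branch at run time.\n",
"print(op_types(tensor_cond, tf.constant(2)))"
]
},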
{
"cell_type": "markdown",
"metadata": {
"id": "yho4J0a0ZkQS"
},
"source": [
"### Loops\n",
"\n",
"AutoGraph will convert some `for` and `while` statements into the equivalent TensorFlow looping ops, like `tf.while_loop`. If not converted, the `for` or `while` loop is executed as a Python loop.\n",
"\n",
"This substitution is made in the following situations:\n",
"\n",
"- `for x in y`: if `y` is a Tensor, convert to `tf.while_loop`. In the special case where `y` is a `tf.data.Dataset`, a combination of `tf.data.Dataset` ops is generated.\n",
"- `while <condition>`: if `<condition>` is a Tensor, convert to `tf.while_loop`.\n",
"\n",
"A Python loop executes during tracing, adding additional ops to the `tf.Graph` for every iteration of the loop.\n",
"\n",
"A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to run at execution time. The loop body only appears once in the generated `tf.Graph`.\n",
"\n",
"See the [reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#while-statements) for additional restrictions on AutoGraph-converted `for` and `while` statements."
]
},
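{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of the two cases described above, the cell below traces the same loop body once with a Python `range` and once with `tf.range`, then compares graph sizes. The names `count_nodes`, `python_loop`, and `tensor_loop` are hypothetical and introduced only for this illustration; the exact node counts depend on the TensorFlow version, so none are asserted here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"def count_nodes(fn, *args):\n",
"  # Hypothetical helper: number of nodes in the graph traced for these arguments.\n",
"  graph = fn.get_concrete_function(*args).graph\n",
"  return len(graph.as_graph_def().node)\n",
"\n",
"@tf.function\n",
"def python_loop(x):\n",
"  for _ in range(10):  # Python range: unrolled during tracing.\n",
"    x = x + 1\n",
"  return x\n",
"\n",
"@tf.function\n",
"def tensor_loop(x):\n",
"  for _ in tf.range(10):  # Tensor range: converted to a TensorFlow loop.\n",
"    x = x + 1\n",
"  return x\n",
"\n",
"# The Python loop adds ops for every iteration; the TensorFlow loop\n",
"# keeps a single copy of the body.\n",
"print('Python loop nodes:', count_nodes(python_loop, tf.constant(0)))\n",
"print('Tensor loop nodes:', count_nodes(tensor_loop, tf.constant(0)))"
]
},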
{
"cell_type": "markdown",
"metadata": {
"id": "sp4rbIdfbM6s"
},
"source": [
"#### Looping over Python data\n",
"\n",
"A common pitfall is to loop over Python/NumPy data within a `tf.function`. This loop will execute during the tracing process, adding a copy of your model to the `tf.Graph` for each iteration of the loop.\n",
"\n",
"If you want to wrap the entire training loop in `tf.function`, the safest way to do this is to wrap your data as a `tf.data.Dataset` so that AutoGraph will dynamically unroll the training loop."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.252292Z",
"iopub.status.busy": "2024-08-15T02:57:36.252044Z",
"iopub.status.idle": "2024-08-15T02:57:36.399324Z",
"shell.execute_reply": "2024-08-15T02:57:36.398619Z"
},
"id": "WGZ19LspbZ27"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph\n",
"train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train(<_FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph\n",
"train(<_FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph\n"
]
}
],
"source": [
"def measure_graph_size(f, *args):\n",
"  g = f.get_concrete_function(*args).graph\n",
"  print(\"{}({}) contains {} nodes in its graph\".format(\n",
"      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))\n",
"\n",
"@tf.function\n",
"def train(dataset):\n",
"  loss = tf.constant(0)\n",
"  for x, y in dataset:\n",
"    loss += tf.abs(y - x)  # Some dummy computation.\n",
"  return loss\n",
"\n",
"small_data = [(1, 1)] * 3\n",
"big_data = [(1, 1)] * 10\n",
"measure_graph_size(train, small_data)\n",
"measure_graph_size(train, big_data)\n",
"\n",
"measure_graph_size(train, tf.data.Dataset.from_generator(\n",
"    lambda: small_data, (tf.int32, tf.int32)))\n",
"measure_graph_size(train, tf.data.Dataset.from_generator(\n",
"    lambda: big_data, (tf.int32, tf.int32)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JeD2U-yrbfVb"
},
"source": [
"When wrapping Python/NumPy data in a Dataset, be mindful of `tf.data.Dataset.from_generator` versus `tf.data.Dataset.from_tensor_slices`. The former keeps the data in Python and fetches it via `tf.py_function`, which can have performance implications, whereas the latter bundles a copy of the data as one large `tf.constant()` node in the graph, which can have memory implications.\n",
"\n",
"Reading data from files via `TFRecordDataset`, `CsvDataset`, etc. is the most effective way to consume data, as then TensorFlow itself can manage the asynchronous loading and prefetching of data, without having to involve Python. To learn more, see the [`tf.data`: Build TensorFlow input pipelines](data.ipynb) guide."
]
},
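{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two constructors mentioned above can be sketched side by side. This is only an illustration: the array `data` and the function `total` are hypothetical names, and the `output_signature` shown is an assumption that matches this toy array rather than a general recommendation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"data = np.arange(6, dtype=np.int32).reshape(3, 2)\n",
"\n",
"# Keeps `data` in Python; rows are fetched through the Python runtime.\n",
"ds_gen = tf.data.Dataset.from_generator(\n",
"    lambda: iter(data),\n",
"    output_signature=tf.TensorSpec(shape=(2,), dtype=tf.int32))\n",
"\n",
"# Copies `data` into the graph as a constant.\n",
"ds_slices = tf.data.Dataset.from_tensor_slices(data)\n",
"\n",
"@tf.function\n",
"def total(dataset):\n",
"  acc = tf.constant(0)\n",
"  for row in dataset:  # Converted by AutoGraph into Dataset ops.\n",
"    acc += tf.reduce_sum(row)\n",
"  return acc\n",
"\n",
"# Both datasets hold the same rows, so the two totals should agree.\n",
"print(total(ds_gen).numpy(), total(ds_slices).numpy())"
]
},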
{
"cell_type": "markdown",
"metadata": {
"id": "hyksHW9TCukR"
},
"source": [
"#### Accumulating values in a loop\n",
"\n",
"A common pattern is to accumulate intermediate values from a loop. Normally, this is accomplished by appending to a Python list or adding entries to a Python dictionary. However, as these are Python side effects, they will not work as expected in a dynamically unrolled loop. Use `tf.TensorArray` to accumulate results from a dynamically unrolled loop."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.402646Z",
"iopub.status.busy": "2024-08-15T02:57:36.402408Z",
"iopub.status.idle": "2024-08-15T02:57:36.556618Z",
"shell.execute_reply": "2024-08-15T02:57:36.555885Z"
},
"id": "HJ3Vb3dXfefN"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch_size = 2\n",
"seq_len = 3\n",
"feature_size = 4\n",
"\n",
"def rnn_step(inp, state):\n",
"  return inp + state\n",
"\n",
"@tf.function\n",
"def dynamic_rnn(rnn_step, input_data, initial_state):\n",
"  # [batch, time, features] -> [time, batch, features]\n",
"  input_data = tf.transpose(input_data, [1, 0, 2])\n",
"  max_seq_len = input_data.shape[0]\n",
"\n",
"  states = tf.TensorArray(tf.float32, size=max_seq_len)\n",
"  state = initial_state\n",
"  for i in tf.range(max_seq_len):\n",
"    state = rnn_step(input_data[i], state)\n",
"    states = states.write(i, state)\n",
"  return tf.transpose(states.stack(), [1, 0, 2])\n",
"\n",
"dynamic_rnn(rnn_step,\n",
"            tf.random.uniform([batch_size, seq_len, feature_size]),\n",
"            tf.zeros([batch_size, feature_size]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i2MVoIVaNApG"
},
"source": [
"## Limitations\n",
"\n",
"`tf.function` has a few limitations by design that you should be aware of when converting a Python function to a `tf.function`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EJqHGFSVLIKl"
},
"source": [
"### Executing Python side effects\n",
"\n",
"Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly inside a `tf.function`, sometimes executing twice or not at all. They only happen the first time you call a `tf.function` with a set of inputs. Afterwards, the traced `tf.Graph` is re-executed without running the Python code.\n",
"\n",
"The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces. Otherwise, TensorFlow APIs like `tf.data`, `tf.print`, `tf.summary`, `tf.Variable.assign`, and `tf.TensorArray` are the best way to ensure your code will be executed by the TensorFlow runtime with each call."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.559990Z",
"iopub.status.busy": "2024-08-15T02:57:36.559752Z",
"iopub.status.idle": "2024-08-15T02:57:36.607158Z",
"shell.execute_reply": "2024-08-15T02:57:36.606483Z"
},
"id": "w2sACuZ9TTRk"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Traced with 1\n",
"Executed with 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Executed with 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Traced with 2\n",
"Executed with 2\n"
]
}
],
"source": [
"@tf.function\n",
"def f(x):\n",
"  print(\"Traced with\", x)\n",
"  tf.print(\"Executed with\", x)\n",
"\n",
"f(1)\n",
"f(1)\n",
"f(2)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e1I0dPiqTV8H"
},
"source": [
"If you would like to execute Python code during each invocation of a `tf.function`, `tf.py_function` is an escape hatch. The drawbacks of `tf.py_function` are that it's not portable or particularly performant, cannot be saved with `SavedModel`, and does not work well in distributed (multi-GPU, TPU) setups. Also, since `tf.py_function` has to be wired into the graph, it casts all inputs/outputs to tensors."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.610565Z",
"iopub.status.busy": "2024-08-15T02:57:36.609971Z",
"iopub.status.idle": "2024-08-15T02:57:36.614047Z",
"shell.execute_reply": "2024-08-15T02:57:36.613491Z"
},
"id": "ZbI7XA_e6yA2"
},
"outputs": [],
"source": [
"@tf.py_function(Tout=tf.float32)\n",
"def py_plus(x, y):\n",
"  print('Executing eagerly.')\n",
"  return x + y\n",
"\n",
"@tf.function\n",
"def tf_wrapper(x, y):\n",
"  print('Tracing.')\n",
"  return py_plus(x, y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h5ttN_sI7TdQ"
},
"source": [
"The `tf.function` will trace the first time:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.617094Z",
"iopub.status.busy": "2024-08-15T02:57:36.616845Z",
"iopub.status.idle": "2024-08-15T02:57:36.675821Z",
"shell.execute_reply": "2024-08-15T02:57:36.675203Z"
},
"id": "mAK4XINl7Ldy"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing.\n",
"Executing eagerly.\n"
]
},
{
"data": {
"text/plain": [
"3.0"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Atxvrd_o7dSy"
},
"source": [
"But the `tf.py_function` inside executes eagerly every time:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.679282Z",
"iopub.status.busy": "2024-08-15T02:57:36.678706Z",
"iopub.status.idle": "2024-08-15T02:57:36.684756Z",
"shell.execute_reply": "2024-08-15T02:57:36.684201Z"
},
"id": "vv7qTiTU7bjy"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Executing eagerly.\n"
]
},
{
"data": {
"text/plain": [
"3.0"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bOW1v9WVKGgH"
},
"source": [
"#### Changing Python global and free variables\n",
"\n",
"Changing Python global and [free variables](https://docs.python.org/3/reference/executionmodel.html#binding-of-names) counts as a Python side effect, so it only happens during tracing.\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.688171Z",
"iopub.status.busy": "2024-08-15T02:57:36.687601Z",
"iopub.status.idle": "2024-08-15T02:57:36.717181Z",
"shell.execute_reply": "2024-08-15T02:57:36.716551Z"
},
"id": "7aJD--9qTWmg"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Python side effect\n"
]
}
],
"source": [
"external_list = []\n",
"\n",
"@tf.function\n",
"def side_effect(x):\n",
"  print('Python side effect')\n",
"  external_list.append(x)\n",
"\n",
"side_effect(1)\n",
"side_effect(1)\n",
"side_effect(1)\n",
"# The list append only happened once!\n",
"assert len(external_list) == 1"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5eZTFRv_k_nR"
},
"source": [
"Sometimes unexpected behaviors are very hard to notice. In the example below, the `counter` is intended to safeguard the increment of a variable. However, because it is a Python integer and not a TensorFlow object, its value is captured during the first trace. When the `tf.function` is used, the `assign_add` is recorded unconditionally in the underlying graph, so `v` increases by 1 every time the `tf.function` is called. This issue is common among users who migrate their graph-mode TensorFlow code to TensorFlow 2 using `tf.function` decorators, when Python side effects (the `counter` in the example) are used to determine which ops to run (`assign_add` in the example). Usually, users realize this only after seeing suspicious numerical results, or significantly lower performance than expected (for example, if the guarded operation is very costly)."
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.720199Z",
"iopub.status.busy": "2024-08-15T02:57:36.719925Z",
"iopub.status.idle": "2024-08-15T02:57:36.772295Z",
"shell.execute_reply": "2024-08-15T02:57:36.771673Z"
},
"id": "5r6p7-9jk_3L"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"2\n",
"3\n"
]
}
],
"source": [
"class Model(tf.Module):\n",
"  def __init__(self):\n",
"    self.v = tf.Variable(0)\n",
"    self.counter = 0\n",
"\n",
"  @tf.function\n",
"  def __call__(self):\n",
"    if self.counter == 0:\n",
"      # A python side-effect\n",
"      self.counter += 1\n",
"      self.v.assign_add(1)\n",
"\n",
"    return self.v\n",
"\n",
"m = Model()\n",
"for n in range(3):\n",
" print(m().numpy()) # prints 1, 2, 3"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tXCTcHoVcxhX"
},
"source": [
"A workaround to achieve the expected behavior is to use [`tf.init_scope`](https://www.tensorflow.org/api_docs/python/tf/init_scope) to lift the operations outside of the function graph. This ensures that the variable increment is only done once, at tracing time. Note that `init_scope` has other side effects, including clearing control flow and the gradient tape. In some cases, using `init_scope` can become too complex to manage realistically."
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.775282Z",
"iopub.status.busy": "2024-08-15T02:57:36.775033Z",
"iopub.status.idle": "2024-08-15T02:57:36.829201Z",
"shell.execute_reply": "2024-08-15T02:57:36.828605Z"
},
"id": "An4MrIbrcvi8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"1\n",
"1\n"
]
}
],
"source": [
"class Model(tf.Module):\n",
"  def __init__(self):\n",
"    self.v = tf.Variable(0)\n",
"    self.counter = 0\n",
"\n",
"  @tf.function\n",
"  def __call__(self):\n",
"    if self.counter == 0:\n",
"      # Lifts ops out of function-building graphs\n",
"      with tf.init_scope():\n",
"        self.counter += 1\n",
"        self.v.assign_add(1)\n",
"\n",
"    return self.v\n",
"\n",
"m = Model()\n",
"for n in range(3):\n",
" print(m().numpy()) # prints 1, 1, 1"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pbFG5CX4LwQA"
},
"source": [
"In summary, as a rule of thumb, you should avoid mutating Python objects, such as integers or containers like lists, that live outside the `tf.function`. Instead, use arguments and TF objects. For example, the section [\"Accumulating values in a loop\"](#accumulating_values_in_a_loop) has one example of how list-like operations can be implemented.\n",
"\n",
"You can, in some cases, capture and manipulate state if it is a [`tf.Variable`](https://www.tensorflow.org/guide/variable). This is how the weights of Keras models are updated with repeated calls to the same `ConcreteFunction`."
]
},
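{
"cell_type": "markdown",
"metadata": {},
"source": [
"As one possible sketch of keeping such state in a `tf.Variable`, the guarded increment can use a variable as the counter, so the check becomes a tensor condition that is evaluated on every call instead of once during tracing. The class name `VariableCounterModel` is hypothetical and this is an illustration under that assumption, not the only way to structure it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"class VariableCounterModel(tf.Module):\n",
"  def __init__(self):\n",
"    self.v = tf.Variable(0)\n",
"    self.counter = tf.Variable(0)  # TensorFlow state instead of a Python int.\n",
"\n",
"  @tf.function\n",
"  def __call__(self):\n",
"    # Tensor condition: AutoGraph converts this `if` to `tf.cond`,\n",
"    # so it is checked at execution time on every call.\n",
"    if tf.equal(self.counter, 0):\n",
"      self.v.assign_add(1)\n",
"    self.counter.assign_add(1)\n",
"    return self.v.read_value()\n",
"\n",
"m = VariableCounterModel()\n",
"for n in range(3):\n",
"  print(m().numpy())"
]
},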
{
"cell_type": "markdown",
"metadata": {
"id": "X_oNNGrAqPJ1"
},
"source": [
"#### Using Python iterators and generators"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "msTmv-oyUNaf"
},
"source": [
"Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in eager mode, they are examples of Python side effects and therefore only happen during tracing."
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.832417Z",
"iopub.status.busy": "2024-08-15T02:57:36.832185Z",
"iopub.status.idle": "2024-08-15T02:57:36.866634Z",
"shell.execute_reply": "2024-08-15T02:57:36.865968Z"
},
"id": "FNPD4unZUedH"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Value: 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Value: 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Value: 1\n"
]
}
],
"source": [
"@tf.function\n",
"def buggy_consume_next(iterator):\n",
"  tf.print(\"Value:\", next(iterator))\n",
"\n",
"iterator = iter([1, 2, 3])\n",
"buggy_consume_next(iterator)\n",
"# This reuses the first value from the iterator, rather than consuming the next value.\n",
"buggy_consume_next(iterator)\n",
"buggy_consume_next(iterator)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wcS3TAgCjTWR"
},
"source": [
"Just like how TensorFlow has a specialized `tf.TensorArray` for list constructs, it has a specialized `tf.data.Iterator` for iteration constructs. See the section on [AutoGraph transformations](#autograph_transformations) for an overview. Also, the [`tf.data`](https://www.tensorflow.org/guide/data) API can help implement generator patterns:\n"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.870148Z",
"iopub.status.busy": "2024-08-15T02:57:36.869643Z",
"iopub.status.idle": "2024-08-15T02:57:36.915434Z",
"shell.execute_reply": "2024-08-15T02:57:36.914774Z"
},
"id": "8D_iKetXW6VE"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Value: 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Value: 2\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Value: 3\n"
]
}
],
"source": [
"@tf.function\n",
"def good_consume_next(iterator):\n",
"  # This is ok, iterator is a tf.data.Iterator\n",
"  tf.print(\"Value:\", next(iterator))\n",
"\n",
"ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])\n",
"iterator = iter(ds)\n",
"good_consume_next(iterator)\n",
"good_consume_next(iterator)\n",
"good_consume_next(iterator)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i8YAMYb6KEh4"
},
"source": [
"### All outputs of a tf.function must be return values\n",
"\n",
"With the exception of `tf.Variable`s, a `tf.function` must return all its\n",
"outputs. Attempting to directly access any tensors from a function without\n",
"going through return values causes \"leaks\".\n",
"\n",
"For example, the function below \"leaks\" the tensor `a` through the Python\n",
"global `x`:"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.918565Z",
"iopub.status.busy": "2024-08-15T02:57:36.918342Z",
"iopub.status.idle": "2024-08-15T02:57:36.957213Z",
"shell.execute_reply": "2024-08-15T02:57:36.956616Z"
},
"id": "zrdp4rjxg6jo"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3\n",
"'SymbolicTensor' object has no attribute 'numpy'\n"
]
}
],
"source": [
"x = None\n",
"\n",
"@tf.function\n",
"def leaky_function(a):\n",
"  global x\n",
"  x = a + 1  # Bad - leaks local tensor\n",
"  return a + 2\n",
"\n",
"correct_a = leaky_function(tf.constant(1))\n",
"\n",
"print(correct_a.numpy())  # Good - value obtained from function's returns\n",
"try:\n",
"  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here\n",
"except AttributeError as expected:\n",
"  print(expected)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-d4_J_DC5rxX"
},
"source": [
"This is true even if the leaked value is also returned:"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:36.960206Z",
"iopub.status.busy": "2024-08-15T02:57:36.959973Z",
"iopub.status.idle": "2024-08-15T02:57:37.017142Z",
"shell.execute_reply": "2024-08-15T02:57:37.016451Z"
},
"id": "PrcpPB8C5s9T"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2\n",
"'SymbolicTensor' object has no attribute 'numpy'\n",
"Caught expected exception \n",
"  <class 'TypeError'>:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/tmpfs/tmp/ipykernel_167534/3551158538.py\", line 8, in assert_raises\n",
" yield\n",
" File \"/tmpfs/tmp/ipykernel_167534/566849597.py\", line 21, in \n",
" captures_leaked_tensor(tf.constant(2))\n",
"TypeError: is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.\n",
"Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.\n",
"\n",
" was defined here:\n",
" File \"/usr/lib/python3.9/runpy.py\", line 197, in _run_module_as_main\n",
" File \"/usr/lib/python3.9/runpy.py\", line 87, in _run_code\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py\", line 18, in \n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelapp.py\", line 739, in start\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tornado/platform/asyncio.py\", line 205, in start\n",
" File \"/usr/lib/python3.9/asyncio/base_events.py\", line 601, in run_forever\n",
" File \"/usr/lib/python3.9/asyncio/base_events.py\", line 1905, in _run_once\n",
" File \"/usr/lib/python3.9/asyncio/events.py\", line 80, in _run\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py\", line 545, in dispatch_queue\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py\", line 534, in process_one\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py\", line 437, in dispatch_shell\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py\", line 362, in execute_request\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py\", line 778, in execute_request\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py\", line 449, in do_execute\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3048, in run_cell\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3103, in _run_cell\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3308, in run_cell_async\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3490, in run_ast_nodes\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py\", line 3550, in run_code\n",
" File \"/tmpfs/tmp/ipykernel_167534/566849597.py\", line 7, in \n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py\", line 150, in error_handler\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 833, in __call__\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 889, in _call\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 696, in _initialize\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py\", line 178, in trace_function\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py\", line 283, in _maybe_define_function\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py\", line 310, in _create_concrete_function\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py\", line 1059, in func_graph_from_py_func\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py\", line 599, in wrapped_fn\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py\", line 41, in autograph_handler\n",
" File \"/tmpfs/tmp/ipykernel_167534/566849597.py\", line 4, in leaky_function\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py\", line 150, in error_handler\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/override_binary_operator.py\", line 113, in binary_op_wrapper\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/tensor_math_operator_overrides.py\", line 28, in _add_dispatch_factory\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py\", line 150, in error_handler\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py\", line 1260, in op_dispatch_handler\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py\", line 1701, in _add_dispatch\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py\", line 490, in add_v2\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py\", line 796, in _apply_op_helper\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py\", line 670, in _create_op_internal\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py\", line 2682, in _create_op_internal\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py\", line 1177, in from_node_def\n",
"\n",
"The tensor cannot be accessed from here, because it was defined in FuncGraph(name=leaky_function, id=139959630636096), which is out of scope.\n"
]
}
],
"source": [
"@tf.function\n",
"def leaky_function(a):\n",
"  global x\n",
"  x = a + 1  # Bad - leaks local tensor\n",
"  return x  # Good - uses local tensor\n",
"\n",
"correct_a = leaky_function(tf.constant(1))\n",
"\n",
"print(correct_a.numpy())  # Good - value obtained from function's returns\n",
"try:\n",
"  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here\n",
"except AttributeError as expected:\n",
"  print(expected)\n",
"\n",
"@tf.function\n",
"def captures_leaked_tensor(b):\n",
"  b += x  # Bad - `x` is leaked from `leaky_function`\n",
"  return b\n",
"\n",
"with assert_raises(TypeError):\n",
"  captures_leaked_tensor(tf.constant(2))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Sm2ghjyy50D4"
},
"source": [
"Usually, leaks such as these occur when you use Python statements or data structures.\n",
"In addition to leaking inaccessible tensors, such statements are also likely wrong because they count as Python side effects, which are not guaranteed to execute at every function call.\n",
"\n",
"Other common ways to leak local tensors include mutating an external Python collection or an external object:"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:37.020490Z",
"iopub.status.busy": "2024-08-15T02:57:37.020255Z",
"iopub.status.idle": "2024-08-15T02:57:37.024455Z",
"shell.execute_reply": "2024-08-15T02:57:37.023760Z"
},
"id": "D7bLe8y652wU"
},
"outputs": [],
"source": [
"class MyClass:\n",
"\n",
"  def __init__(self):\n",
"    self.field = None\n",
"\n",
"external_list = []\n",
"external_object = MyClass()\n",
"\n",
"def leaky_function():\n",
"  a = tf.constant(1)\n",
"  external_list.append(a)  # Bad - leaks tensor\n",
"  external_object.field = a  # Bad - leaks tensor"
]
},
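{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of the usual fix (the names `make_tensor` and `collected` are hypothetical, chosen only for this illustration), return the tensor from the `tf.function` and store it from eager code, instead of mutating external state while the function is traced:\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"@tf.function\n",
"def make_tensor():\n",
"  a = tf.constant(1)\n",
"  return a  # Good - the value leaves the function through its return value\n",
"\n",
"collected = []  # Hypothetical external list, only mutated in eager code\n",
"collected.append(make_tensor())  # The returned value is an eager tensor\n",
"value = collected[0].numpy()  # Accessible outside the function\n",
"```\n",
"\n",
"Because the list is modified outside the traced function, it holds a concrete eager tensor rather than a symbolic tensor that belongs to the function's graph."
]
},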
{
"cell_type": "markdown",
"metadata": {
"id": "g-XVQcD-wf5K"
},
"source": [
"### Recursive tf.functions are not supported\n",
"\n",
"Recursive `tf.function`s are not supported and could cause infinite loops. For example,"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:37.027689Z",
"iopub.status.busy": "2024-08-15T02:57:37.027443Z",
"iopub.status.idle": "2024-08-15T02:57:37.896353Z",
"shell.execute_reply": "2024-08-15T02:57:37.895551Z"
},
"id": "QSN-T1m5EFcR"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Caught expected exception \n",
"  <class 'Exception'>:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/tmpfs/tmp/ipykernel_167534/3551158538.py\", line 8, in assert_raises\n",
" yield\n",
" File \"/tmpfs/tmp/ipykernel_167534/2233998312.py\", line 9, in <module>\n",
" recursive_fn(tf.constant(5)) # Bad - maximum recursion error.\n",
"tensorflow.python.autograph.impl.api.StagingError: in user code:\n",
"\n",
" File \"/tmpfs/tmp/ipykernel_167534/2233998312.py\", line 4, in recursive_fn *\n",
" return recursive_fn(n - 1)\n",
" File \"/tmpfs/tmp/ipykernel_167534/2233998312.py\", line 4, in recursive_fn *\n",
" return recursive_fn(n - 1)\n",
"\n",
" RecursionError: maximum recursion depth exceeded while calling a Python object\n",
"\n"
]
}
],
"source": [
"@tf.function\n",
"def recursive_fn(n):\n",
"  if n > 0:\n",
"    return recursive_fn(n - 1)\n",
"  else:\n",
"    return 1\n",
"\n",
"with assert_raises(Exception):\n",
"  recursive_fn(tf.constant(5))  # Bad - maximum recursion error."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LyRyooKGUxNV"
},
"source": [
"Even if a recursive `tf.function` seems to work, the Python function is traced again for each new argument value in the recursion, which can hurt performance. For example,"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:37.900614Z",
"iopub.status.busy": "2024-08-15T02:57:37.900366Z",
"iopub.status.idle": "2024-08-15T02:57:37.968993Z",
"shell.execute_reply": "2024-08-15T02:57:37.968397Z"
},
"id": "7FlmTqfMUwmT"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tracing\n",
"tracing\n",
"tracing\n",
"tracing\n",
"tracing\n"
]
},
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(), dtype=int32, numpy=1>"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@tf.function\n",
"def recursive_fn(n):\n",
"  if n > 0:\n",
"    print('tracing')\n",
"    return recursive_fn(n - 1)\n",
"  else:\n",
"    return 1\n",
"\n",
"recursive_fn(5)  # Warning - multiple tracings"
]
},
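{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to avoid recursion is to express the computation as a loop. The sketch below assumes the recursion can be rewritten iteratively; `iterative_factorial` is a hypothetical example rather than part of this guide. It uses `tf.while_loop` so that the iteration lives inside the graph, and a single trace serves different input values:\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"@tf.function\n",
"def iterative_factorial(n):\n",
"  i0 = tf.constant(1, dtype=n.dtype)\n",
"  acc0 = tf.constant(1, dtype=n.dtype)\n",
"  # Loop in the graph: multiply acc by i for i = 1..n.\n",
"  _, result = tf.while_loop(\n",
"      cond=lambda i, acc: i <= n,\n",
"      body=lambda i, acc: (i + 1, acc * i),\n",
"      loop_vars=(i0, acc0))\n",
"  return result\n",
"\n",
"five = iterative_factorial(tf.constant(5))\n",
"six = iterative_factorial(tf.constant(6))  # Same dtype and shape: reuses the trace\n",
"trace_count = iterative_factorial.experimental_get_tracing_count()\n",
"```\n",
"\n",
"Both calls pass scalar `int32` tensors, so `trace_count` stays at one, in contrast to the recursive version above, which is traced once per argument value."
]
},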
{
"cell_type": "markdown",
"metadata": {
"id": "-D6nh3QirXAd"
},
"source": [
"## Known Issues\n",
"\n",
"If your `tf.function` is not evaluating correctly, the cause may be one of these known issues, which are planned to be fixed in the future."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZoPg5w1Pjqna"
},
"source": [
"### Depending on Python global and free variables\n",
"\n",
"`tf.function` creates a new `ConcreteFunction` when called with a new value of a Python argument. However, it does not do that for the Python closure, globals, or nonlocals of that `tf.function`. If their value changes in between calls to the `tf.function`, the `tf.function` will still use the values they had when it was traced. This is different from how regular Python functions work.\n",
"\n",
"For that reason, you should follow a functional programming style that uses arguments instead of closing over outer names."
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:37.972525Z",
"iopub.status.busy": "2024-08-15T02:57:37.972266Z",
"iopub.status.idle": "2024-08-15T02:57:38.031488Z",
"shell.execute_reply": "2024-08-15T02:57:38.030861Z"
},
"id": "oeJMdXd3M0cM"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Buggy: tf.Tensor(2, shape=(), dtype=int32)\n",
"Correct: tf.Tensor(2, shape=(), dtype=int32)\n"
]
}
],
"source": [
"@tf.function\n",
"def buggy_add():\n",
"  return 1 + foo\n",
"\n",
"@tf.function\n",
"def recommended_add(foo):\n",
"  return 1 + foo\n",
"\n",
"foo = 1\n",
"print(\"Buggy:\", buggy_add())\n",
"print(\"Correct:\", recommended_add(foo))"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:38.035018Z",
"iopub.status.busy": "2024-08-15T02:57:38.034595Z",
"iopub.status.idle": "2024-08-15T02:57:38.051498Z",
"shell.execute_reply": "2024-08-15T02:57:38.050854Z"
},
"id": "L3q7sUJWZOSU"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Updating the value of `foo` to 100!\n",
"Buggy: tf.Tensor(2, shape=(), dtype=int32)\n",
"Correct: tf.Tensor(101, shape=(), dtype=int32)\n"
]
}
],
"source": [
"print(\"Updating the value of `foo` to 100!\")\n",
"foo = 100\n",
"print(\"Buggy:\", buggy_add()) # Did not change!\n",
"print(\"Correct:\", recommended_add(foo))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZoPg5w1Pjqnb"
},
"source": [
"Another way to update a global value is to make it a `tf.Variable` and use the `Variable.assign` method instead.\n"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:38.054589Z",
"iopub.status.busy": "2024-08-15T02:57:38.054357Z",
"iopub.status.idle": "2024-08-15T02:57:38.090878Z",
"shell.execute_reply": "2024-08-15T02:57:38.090317Z"
},
"id": "oeJMdXd3M0cc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Variable: tf.Tensor(2, shape=(), dtype=int32)\n"
]
}
],
"source": [
"@tf.function\n",
"def variable_add():\n",
"  return 1 + foo\n",
"\n",
"foo = tf.Variable(1)\n",
"print(\"Variable:\", variable_add())\n"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:38.093796Z",
"iopub.status.busy": "2024-08-15T02:57:38.093553Z",
"iopub.status.idle": "2024-08-15T02:57:38.098419Z",
"shell.execute_reply": "2024-08-15T02:57:38.097783Z"
},
"id": "L3q7sUJWZOSd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Updating the value of `foo` to 100!\n",
"Variable: tf.Tensor(101, shape=(), dtype=int32)\n"
]
}
],
"source": [
"print(\"Updating the value of `foo` to 100!\")\n",
"foo.assign(100)\n",
"print(\"Variable:\", variable_add())"
]
},
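{
"cell_type": "markdown",
"metadata": {},
"source": [
"Passing the value as an argument has a cost of its own when the argument is a Python number: as noted above, each new Python value creates a new `ConcreteFunction`. A minimal sketch, assuming the value can be wrapped in a tensor (the names `add_one_py` and `add_one_tensor` are hypothetical), compares the resulting trace counts:\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"@tf.function\n",
"def add_one_py(value):\n",
"  return 1 + value\n",
"\n",
"@tf.function\n",
"def add_one_tensor(value):\n",
"  return 1 + value\n",
"\n",
"for v in (1, 100):\n",
"  add_one_py(v)  # Python ints: one trace per distinct value\n",
"  add_one_tensor(tf.constant(v))  # Same dtype and shape: one shared trace\n",
"\n",
"py_traces = add_one_py.experimental_get_tracing_count()\n",
"tensor_traces = add_one_tensor.experimental_get_tracing_count()\n",
"```\n",
"\n",
"Here `py_traces` ends up larger than `tensor_traces`, which is why tensor arguments or `tf.Variable`s are usually preferable for values that change often."
]
},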
{
"cell_type": "markdown",
"metadata": {
"id": "hvwe9gTIWfx6"
},
"source": [
"### Depending on Python objects"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BJkZS-SwPvOQ"
},
"source": [
"Passing custom Python objects as arguments to `tf.function` is supported but has certain limitations.\n",
"\n",
"For maximum feature coverage, consider transforming the objects into [Extension types](extension_type.ipynb) before passing them to `tf.function`. You can also use Python primitives and `tf.nest`-compatible structures.\n",
"\n",
"However, as covered in the [rules of tracing](#rules_of_tracing), when the custom Python class does not provide a custom `TraceType`, `tf.function` is forced to use instance-based equality, which means it will **not create a new trace** when you pass the **same object with modified attributes**."
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:38.101888Z",
"iopub.status.busy": "2024-08-15T02:57:38.101418Z",
"iopub.status.idle": "2024-08-15T02:57:38.141416Z",
"shell.execute_reply": "2024-08-15T02:57:38.140811Z"
},
"id": "ux8KJESVWDxX"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(20.0, shape=(), dtype=float32)\n"
]
}
],
"source": [
"class SimpleModel(tf.Module):\n",
"  def __init__(self):\n",
"    # These values are *not* tf.Variables.\n",
"    self.bias = 0.\n",
"    self.weight = 2.\n",
"\n",
"@tf.function\n",
"def evaluate(model, x):\n",
"  return model.weight * x + model.bias\n",
"\n",
"simple_model = SimpleModel()\n",
"x = tf.constant(10.)\n",
"print(evaluate(simple_model, x))"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:38.144244Z",
"iopub.status.busy": "2024-08-15T02:57:38.143979Z",
"iopub.status.idle": "2024-08-15T02:57:38.148502Z",
"shell.execute_reply": "2024-08-15T02:57:38.147830Z"
},
"id": "mUxRF4ghZZvX"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Adding bias!\n",
"tf.Tensor(20.0, shape=(), dtype=float32)\n"
]
}
],
"source": [
"print(\"Adding bias!\")\n",
"simple_model.bias += 5.0\n",
"print(evaluate(simple_model, x)) # Didn't change :("
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ytcgg2qFWaBF"
},
"source": [
"Using the same `tf.function` to evaluate the modified instance of the model gives a stale result, since the instance still has the [same instance-based TraceType](#rules_of_tracing) as the original model.\n",
"\n",
"For that reason, write your `tf.function` so that it does not depend on mutable object attributes, or implement the [Tracing Protocol](#use_the_tracing_protocol) for the objects to inform `tf.function` about such attributes.\n",
"\n",
"If that is not possible, one workaround is to make new `tf.function`s each time you modify your object to force retracing:"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:38.151669Z",
"iopub.status.busy": "2024-08-15T02:57:38.151432Z",
"iopub.status.idle": "2024-08-15T02:57:38.191499Z",
"shell.execute_reply": "2024-08-15T02:57:38.190873Z"
},
"id": "pFvWmWAAQjrv"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(20.0, shape=(), dtype=float32)\n"
]
}
],
"source": [
"def evaluate(model, x):\n",
"  return model.weight * x + model.bias\n",
"\n",
"new_model = SimpleModel()\n",
"evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)\n",
"# Don't pass in `new_model`. `tf.function` already captured its state during tracing.\n",
"print(evaluate_no_bias(x))"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:38.194402Z",
"iopub.status.busy": "2024-08-15T02:57:38.194171Z",
"iopub.status.idle": "2024-08-15T02:57:38.216428Z",
"shell.execute_reply": "2024-08-15T02:57:38.215709Z"
},
"id": "bdU2-jF4ZH0B"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Adding bias!\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(25.0, shape=(), dtype=float32)\n"
]
}
],
"source": [
"print(\"Adding bias!\")\n",
"new_model.bias += 5.0\n",
"# Create new `tf.function` and `ConcreteFunction` since you modified `new_model`.\n",
"evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)\n",
"print(evaluate_with_bias(x)) # Don't pass in `new_model`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uFgEZClsZrEi"
},
"source": [
"As [retracing can be expensive](https://www.tensorflow.org/guide/intro_to_graphs#tracing_and_performance), you can instead use `tf.Variable`s as object attributes. Their values can be mutated in place (but be careful not to rebind the attributes to new objects) for a similar effect without needing a retrace.\n"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:38.219766Z",
"iopub.status.busy": "2024-08-15T02:57:38.219519Z",
"iopub.status.idle": "2024-08-15T02:57:38.263280Z",
"shell.execute_reply": "2024-08-15T02:57:38.262594Z"
},
"id": "daAP_lucwS6w"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(20.0, shape=(), dtype=float32)\n"
]
}
],
"source": [
"class BetterModel:\n",
"\n",
"  def __init__(self):\n",
"    self.bias = tf.Variable(0.)\n",
"    self.weight = tf.Variable(2.)\n",
"\n",
"@tf.function\n",
"def evaluate(model, x):\n",
"  return model.weight * x + model.bias\n",
"\n",
"better_model = BetterModel()\n",
"print(evaluate(better_model, x))\n"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:38.266265Z",
"iopub.status.busy": "2024-08-15T02:57:38.265998Z",
"iopub.status.idle": "2024-08-15T02:57:38.272572Z",
"shell.execute_reply": "2024-08-15T02:57:38.271976Z"
},
"id": "ktqwMJBqwTFj"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Adding bias!\n",
"tf.Tensor(25.0, shape=(), dtype=float32)\n"
]
}
],
"source": [
"print(\"Adding bias!\")\n",
"better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5\n",
"print(evaluate(better_model, x)) # This works!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lPr_6mK_AQWL"
},
"source": [
"### Creating tf.Variables\n",
"\n",
"`tf.function` only supports singleton `tf.Variable`s created once on the first call, and reused across subsequent function calls. The code snippet below would create a new `tf.Variable` in every function call, which results in a `ValueError` exception.\n",
"\n",
"Example:"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:38.275962Z",
"iopub.status.busy": "2024-08-15T02:57:38.275428Z",
"iopub.status.idle": "2024-08-15T02:57:38.329174Z",
"shell.execute_reply": "2024-08-15T02:57:38.328570Z"
},
"id": "Tx0Vvnb_9OB-"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Caught expected exception \n",
"  <class 'ValueError'>:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/tmpfs/tmp/ipykernel_167534/3551158538.py\", line 8, in assert_raises\n",
" yield\n",
" File \"/tmpfs/tmp/ipykernel_167534/3018268426.py\", line 7, in <module>\n",
" f(1.0)\n",
"ValueError: in user code:\n",
"\n",
" File \"/tmpfs/tmp/ipykernel_167534/3018268426.py\", line 3, in f *\n",
" v = tf.Variable(1.0)\n",
"\n",
" ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.\n",
"\n"
]
}
],
"source": [
"@tf.function\n",
"def f(x):\n",
"  v = tf.Variable(1.0)\n",
"  return v\n",
"\n",
"with assert_raises(ValueError):\n",
"  f(1.0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KYm6-5GCILXQ"
},
"source": [
"A common pattern used to work around this limitation is to start with a Python `None` value, then conditionally create the `tf.Variable` if the value is `None`:"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:38.333001Z",
"iopub.status.busy": "2024-08-15T02:57:38.332739Z",
"iopub.status.idle": "2024-08-15T02:57:38.413836Z",
"shell.execute_reply": "2024-08-15T02:57:38.413177Z"
},
"id": "HQrG5_kOiKl_"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(1, shape=(), dtype=int32)\n",
"tf.Tensor(2, shape=(), dtype=int32)\n"
]
}
],
"source": [
"class Count(tf.Module):\n",
"  def __init__(self):\n",
"    self.count = None\n",
"\n",
"  @tf.function\n",
"  def __call__(self):\n",
"    if self.count is None:\n",
"      self.count = tf.Variable(0)\n",
"    return self.count.assign_add(1)\n",
"\n",
"c = Count()\n",
"print(c())\n",
"print(c())"
]
},
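{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the initial value is known up front, an alternative is to create the `tf.Variable` eagerly in the constructor, so that the `tf.function` never creates variables at all. This is a sketch with a hypothetical `EagerCount` class, following the error message's suggestion to create the variable outside `tf.function`:\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"class EagerCount(tf.Module):\n",
"  def __init__(self):\n",
"    super().__init__()\n",
"    self.count = tf.Variable(0)  # Created once, outside any tf.function\n",
"\n",
"  @tf.function\n",
"  def __call__(self):\n",
"    return self.count.assign_add(1)\n",
"\n",
"counter = EagerCount()\n",
"first = counter()\n",
"second = counter()\n",
"```\n",
"\n",
"The traced function only reads and updates the existing variable, so it behaves the same on the first call and on every later call."
]
},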
{
"cell_type": "markdown",
"metadata": {
"id": "7uD6qI7aJwbR"
},
"source": [
"#### Using with multiple Keras optimizers\n",
"You may encounter `ValueError: tf.function only supports singleton tf.Variables created on the first call.` when using more than one Keras optimizer with a `tf.function`. This error occurs because optimizers internally create `tf.Variable`s when they apply gradients for the first time."
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:38.417100Z",
"iopub.status.busy": "2024-08-15T02:57:38.416839Z",
"iopub.status.idle": "2024-08-15T02:57:39.037364Z",
"shell.execute_reply": "2024-08-15T02:57:39.036602Z"
},
"id": "yWQ3-r99Jvze"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Calling `train_step` with different optimizer...\n",
"Caught expected exception \n",
" :\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/tmpfs/tmp/ipykernel_167534/3551158538.py\", line 8, in assert_raises\n",
" yield\n",
" File \"/tmpfs/tmp/ipykernel_167534/950644149.py\", line 18, in <module>\n",
" train_step(w, x, y, opt2)\n",
"ValueError: in user code:\n",
"\n",
" File \"/tmpfs/tmp/ipykernel_167534/950644149.py\", line 9, in train_step *\n",
" optimizer.apply_gradients(zip(gradients, [w]))\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/base_optimizer.py\", line 291, in apply_gradients **\n",
" self.apply(grads, trainable_variables)\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/base_optimizer.py\", line 330, in apply\n",
" self.build(trainable_variables)\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/adam.py\", line 97, in build\n",
" self.add_variable_from_reference(\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/backend/tensorflow/optimizer.py\", line 36, in add_variable_from_reference\n",
" return super().add_variable_from_reference(\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/base_optimizer.py\", line 227, in add_variable_from_reference\n",
" return self.add_variable(\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/optimizers/base_optimizer.py\", line 201, in add_variable\n",
" variable = backend.Variable(\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/backend/common/variables.py\", line 163, in __init__\n",
" self._initialize_with_initializer(initializer)\n",
" File \"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/backend/tensorflow/core.py\", line 40, in _initialize_with_initializer\n",
" self._value = tf.Variable(\n",
"\n",
" ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.\n",
"\n"
]
}
],
"source": [
"opt1 = tf.keras.optimizers.Adam(learning_rate=1e-2)\n",
"opt2 = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
"\n",
"@tf.function\n",
"def train_step(w, x, y, optimizer):\n",
"  with tf.GradientTape() as tape:\n",
"    L = tf.reduce_sum(tf.square(w*x - y))\n",
"  gradients = tape.gradient(L, [w])\n",
"  optimizer.apply_gradients(zip(gradients, [w]))\n",
"\n",
"w = tf.Variable(2.)\n",
"x = tf.constant([-1.])\n",
"y = tf.constant([2.])\n",
"\n",
"train_step(w, x, y, opt1)\n",
"print(\"Calling `train_step` with different optimizer...\")\n",
"with assert_raises(ValueError):\n",
"  train_step(w, x, y, opt2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7Q8BRPCThTjB"
},
"source": [
"If you need to change a stateful object between calls, it's simplest to define a `tf.Module` subclass and create one instance per object, so that each instance holds its own state:"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:39.041827Z",
"iopub.status.busy": "2024-08-15T02:57:39.041162Z",
"iopub.status.idle": "2024-08-15T02:57:39.375520Z",
"shell.execute_reply": "2024-08-15T02:57:39.374855Z"
},
"id": "3P59ocmIslHz"
},
"outputs": [],
"source": [
"class TrainStep(tf.Module):\n",
"  def __init__(self, optimizer):\n",
"    self.optimizer = optimizer\n",
"\n",
"  @tf.function\n",
"  def __call__(self, w, x, y):\n",
"    with tf.GradientTape() as tape:\n",
"      L = tf.reduce_sum(tf.square(w*x - y))\n",
"    gradients = tape.gradient(L, [w])\n",
"    self.optimizer.apply_gradients(zip(gradients, [w]))\n",
"\n",
"\n",
"opt1 = tf.keras.optimizers.Adam(learning_rate=1e-2)\n",
"opt2 = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
"\n",
"train_o1 = TrainStep(opt1)\n",
"train_o2 = TrainStep(opt2)\n",
"\n",
"train_o1(w, x, y)\n",
"train_o2(w, x, y)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dUHUi881smHF"
},
"source": [
"You could also do this manually by creating multiple instances of the `@tf.function` wrapper, one for each optimizer:"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T02:57:39.379288Z",
"iopub.status.busy": "2024-08-15T02:57:39.379046Z",
"iopub.status.idle": "2024-08-15T02:57:39.713857Z",
"shell.execute_reply": "2024-08-15T02:57:39.713166Z"
},
"id": "YV5F2Gy9hSI3"
},
"outputs": [],
"source": [
"opt1 = tf.keras.optimizers.Adam(learning_rate=1e-2)\n",
"opt2 = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
"\n",
"# Not a tf.function.\n",
"def train_step(w, x, y, optimizer):\n",
"  with tf.GradientTape() as tape:\n",
"    L = tf.reduce_sum(tf.square(w*x - y))\n",
"  gradients = tape.gradient(L, [w])\n",
"  optimizer.apply_gradients(zip(gradients, [w]))\n",
"\n",
"w = tf.Variable(2.)\n",
"x = tf.constant([-1.])\n",
"y = tf.constant([2.])\n",
"\n",
"# Make a new tf.function and ConcreteFunction for each optimizer.\n",
"train_step_1 = tf.function(train_step)\n",
"train_step_2 = tf.function(train_step)\n",
"for i in range(10):\n",
"  if i % 2 == 0:\n",
"    train_step_1(w, x, y, opt1)\n",
"  else:\n",
"    train_step_2(w, x, y, opt2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Xjnz5CcuqQac"
},
"source": [
"#### Using with multiple Keras models\n",
"\n",
"You may also encounter `ValueError: tf.function only supports singleton tf.Variables created on the first call.` when passing different model instances to the same `tf.function`.\n",
"\n",
"This error occurs because Keras layers, and Keras models that [do not have their input shape defined](https://www.tensorflow.org/guide/keras/custom_layers_and_models#best_practice_deferring_weight_creation_until_the_shape_of_the_inputs_is_known), create their `tf.Variable`s when they are first called. If that first call happens inside a `tf.function` that has already been called with another model, the variables are created on a later call, which is not allowed. To avoid this error, call `model.build(input_shape)` to initialize all the weights before training the model.\n"
]
},
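{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of this workaround, the cell below builds two models before passing them to a shared `tf.function`, so the second trace has no variables left to create. The `make_model` and `predict` helpers, the single `Dense` layer, and the `(None, 3)` input shape are illustrative assumptions rather than part of any particular training setup."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# Hypothetical helper: a small model whose input shape is not yet defined.\n",
"def make_model():\n",
"  return tf.keras.Sequential([tf.keras.layers.Dense(1)])\n",
"\n",
"model_a = make_model()\n",
"model_b = make_model()\n",
"\n",
"# Create the weights eagerly, before any tracing happens.\n",
"for model in (model_a, model_b):\n",
"  model.build(input_shape=(None, 3))\n",
"\n",
"@tf.function\n",
"def predict(model, inputs):\n",
"  return model(inputs)\n",
"\n",
"inputs = tf.ones([2, 3])\n",
"predict(model_a, inputs)\n",
"# A different model instance triggers a new trace, but no new variables.\n",
"predict(model_b, inputs)"
]
},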
{
"cell_type": "markdown",
"metadata": {
"id": "IKyrEY5GVX3M"
},
"source": [
"## Further reading\n",
"\n",
"To learn how to export and load a `tf.function`, see the [SavedModel guide](../../guide/saved_model). To learn more about the graph optimizations performed after tracing, see the [Grappler guide](../../guide/graph_optimization). To learn how to optimize your data pipeline and profile your model, see the [Profiler guide](../../guide/profiler.md)."
]
}
],
"metadata": {
"colab": {
"name": "function.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 0
}