{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "N7ITxKLUkX0v"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "form",
"execution": {
"iopub.execute_input": "2024-08-15T01:23:58.256165Z",
"iopub.status.busy": "2024-08-15T01:23:58.255929Z",
"iopub.status.idle": "2024-08-15T01:23:58.259742Z",
"shell.execute_reply": "2024-08-15T01:23:58.259099Z"
},
"id": "yOYx6tzSnWQ3"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6xgB0Oz5eGSQ"
},
"source": [
"# Introduction to graphs and tf.function"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RBKqnXI9GOax"
},
"source": [
"## Overview\n",
"\n",
"This guide goes beneath the surface of TensorFlow and Keras to demonstrate how TensorFlow works. If you instead want to immediately get started with Keras, check out the [collection of Keras guides](https://www.tensorflow.org/guide/keras/).\n",
"\n",
"In this guide, you'll learn how TensorFlow allows you to make simple changes to your code to get graphs, how graphs are stored and represented, and how you can use them to accelerate your models.\n",
"\n",
"Note: For those of you who are only familiar with TensorFlow 1.x, this guide demonstrates a very different view of graphs.\n",
"\n",
"**This is a big-picture overview that covers how `tf.function` allows you to switch from eager execution to graph execution.** For a more complete specification of `tf.function`, go to the [Better performance with `tf.function`](./function.ipynb) guide.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v0DdlfacAdTZ"
},
"source": [
"### What are graphs?\n",
"\n",
"In the previous three guides, you ran TensorFlow **eagerly**. This means Python executes TensorFlow operations one at a time, and each operation returns its result to Python.\n",
"\n",
"While eager execution has several unique advantages, graph execution enables portability outside Python and tends to offer better performance. **Graph execution** means that tensor computations are executed as a *TensorFlow graph*, sometimes referred to as a `tf.Graph` or simply a \"graph.\"\n",
"\n",
"**Graphs are data structures that contain a set of `tf.Operation` objects, which represent units of computation; and `tf.Tensor` objects, which represent the units of data that flow between operations.** They are defined in a `tf.Graph` context. Since these graphs are data structures, they can be saved, run, and restored all without the original Python code.\n",
"\n",
"This is what a TensorFlow graph representing a two-layer neural network looks like when visualized in TensorBoard:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FvQ5aBuRGT1o"
},
"source": [
"*(Figure: a two-layer neural network's TensorFlow graph, visualized in TensorBoard.)*"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DHpY3avXGITP"
},
"source": [
"### The benefits of graphs\n",
"\n",
"With a graph, you have a great deal of flexibility. You can use your TensorFlow graph in environments that don't have a Python interpreter, like mobile applications, embedded devices, and backend servers. TensorFlow uses graphs as the format for [saved models](./saved_model.ipynb) when it exports them from Python.\n",
"\n",
"Graphs are also easily optimized, allowing the compiler to do transformations like:\n",
"\n",
"* Statically infer the value of tensors by folding constant nodes in your computation *(\"constant folding\")*.\n",
"* Separate sub-parts of a computation that are independent and split them between threads or devices.\n",
"* Simplify arithmetic operations by eliminating common subexpressions.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o1x1EOD9GjnB"
},
"source": [
"There is an entire optimization system, [Grappler](./graph_optimization.ipynb), to perform these and other speedups.\n",
"\n",
"In short, graphs are extremely useful and let your TensorFlow code run **fast**, run **in parallel**, and run efficiently **on multiple devices**.\n",
"\n",
"However, you still want to define your machine learning models (or other computations) in Python for convenience, and then automatically construct graphs when you need them."
]
},
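{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make *constant folding* concrete, here is a toy sketch in plain Python. It is an illustration only: the tuple-based expression format and the `fold_constants` helper are hypothetical and are not how Grappler is implemented. The underlying idea is the same, though: any sub-computation whose inputs are all known ahead of time can be replaced by its result before the graph runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import operator\n",
"\n",
"# A node is a number (a constant), a string (an input known only at run\n",
"# time), or a tuple `(op_name, left, right)`. This format is hypothetical.\n",
"OPS = {'add': operator.add, 'mul': operator.mul}\n",
"\n",
"def fold_constants(node):\n",
"  # Leaves (constants and named inputs) have nothing to fold.\n",
"  if not isinstance(node, tuple):\n",
"    return node\n",
"  op_name, left, right = node\n",
"  left, right = fold_constants(left), fold_constants(right)\n",
"  # If both inputs are now constants, compute the result ahead of time.\n",
"  if isinstance(left, (int, float)) and isinstance(right, (int, float)):\n",
"    return OPS[op_name](left, right)\n",
"  return (op_name, left, right)\n",
"\n",
"# (x * (2 + 3)) + (4 * 5): only `x` is unknown until run time.\n",
"expression = ('add', ('mul', 'x', ('add', 2, 3)), ('mul', 4, 5))\n",
"print(fold_constants(expression))  # ('add', ('mul', 'x', 5), 20)"
]
},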
{
"cell_type": "markdown",
"metadata": {
"id": "k-6Qi0thw2i9"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0d1689fa928f"
},
"source": [
"Import some necessary libraries:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:23:58.263617Z",
"iopub.status.busy": "2024-08-15T01:23:58.263391Z",
"iopub.status.idle": "2024-08-15T01:24:00.577791Z",
"shell.execute_reply": "2024-08-15T01:24:00.577110Z"
},
"id": "goZwOXp_xyQj"
},
"outputs": [
],
"source": [
"import tensorflow as tf\n",
"import timeit\n",
"from datetime import datetime"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pSZebVuWxDXu"
},
"source": [
"## Taking advantage of graphs\n",
"\n",
"You create and run a graph in TensorFlow by using `tf.function`, either as a direct call or as a decorator. `tf.function` takes a regular function as input and returns a `tf.types.experimental.PolymorphicFunction`. **A `PolymorphicFunction` is a Python callable that builds TensorFlow graphs from the Python function. You use a `tf.function` in the same way as its Python equivalent.**\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:00.582158Z",
"iopub.status.busy": "2024-08-15T01:24:00.581775Z",
"iopub.status.idle": "2024-08-15T01:24:02.915853Z",
"shell.execute_reply": "2024-08-15T01:24:02.915156Z"
},
"id": "HKbLeJ1y0Umi"
},
"outputs": [
],
"source": [
"# Define a Python function.\n",
"def a_regular_function(x, y, b):\n",
"  x = tf.matmul(x, y)\n",
"  x = x + b\n",
"  return x\n",
"\n",
"# The Python type of `a_function_that_uses_a_graph` will now be a\n",
"# `PolymorphicFunction`.\n",
"a_function_that_uses_a_graph = tf.function(a_regular_function)\n",
"\n",
"# Make some tensors.\n",
"x1 = tf.constant([[1.0, 2.0]])\n",
"y1 = tf.constant([[2.0], [3.0]])\n",
"b1 = tf.constant(4.0)\n",
"\n",
"orig_value = a_regular_function(x1, y1, b1).numpy()\n",
"# Call a `tf.function` like a Python function.\n",
"tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()\n",
"assert(orig_value == tf_function_value)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PNvuAYpdrTOf"
},
"source": [
"On the outside, a `tf.function` looks like a regular function you write using TensorFlow operations. [Underneath](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/eager/polymorphic_function/polymorphic_function.py), however, it is *very different*. The underlying `PolymorphicFunction` **encapsulates several `tf.Graph`s behind one API** (learn more in the _Polymorphism_ section). That is how a `tf.function` is able to give you the benefits of graph execution, like speed and deployability (refer to _The benefits of graphs_ above)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MT7U8ozok0gV"
},
"source": [
"`tf.function` applies to a function *and all other functions it calls*:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:02.919684Z",
"iopub.status.busy": "2024-08-15T01:24:02.918961Z",
"iopub.status.idle": "2024-08-15T01:24:02.991985Z",
"shell.execute_reply": "2024-08-15T01:24:02.991411Z"
},
"id": "rpz08iLplm9F"
},
"outputs": [
{
"data": {
"text/plain": [
"array([[12.]], dtype=float32)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def inner_function(x, y, b):\n",
"  x = tf.matmul(x, y)\n",
"  x = x + b\n",
"  return x\n",
"\n",
"# Using the `tf.function` decorator makes `outer_function` into a\n",
"# `PolymorphicFunction`.\n",
"@tf.function\n",
"def outer_function(x):\n",
"  y = tf.constant([[2.0], [3.0]])\n",
"  b = tf.constant(4.0)\n",
"\n",
"  return inner_function(x, y, b)\n",
"\n",
"# Note that the callable will create a graph that\n",
"# includes `inner_function` as well as `outer_function`.\n",
"outer_function(tf.constant([[1.0, 2.0]])).numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P88fOr88qgCj"
},
"source": [
"If you have used TensorFlow 1.x, you will notice that at no time did you need to define a `tf.placeholder` or create a `tf.Session`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wfeKf0Nr1OEK"
},
"source": [
"### Converting Python functions to graphs\n",
"\n",
"Any function you write with TensorFlow will contain a mixture of built-in TF operations and Python logic, such as `if`/`else` clauses, loops, `break`, `return`, `continue`, and more. While TensorFlow operations are easily captured by a `tf.Graph`, Python-specific logic needs to undergo an extra step in order to become part of the graph. `tf.function` uses a library called AutoGraph (`tf.autograph`) to convert Python code into graph-generating code.\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:02.995143Z",
"iopub.status.busy": "2024-08-15T01:24:02.994907Z",
"iopub.status.idle": "2024-08-15T01:24:03.057871Z",
"shell.execute_reply": "2024-08-15T01:24:03.057292Z"
},
"id": "PFObpff1BMEb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"First branch, with graph: 1\n",
"Second branch, with graph: 0\n"
]
}
],
"source": [
"def simple_relu(x):\n",
"  if tf.greater(x, 0):\n",
"    return x\n",
"  else:\n",
"    return 0\n",
"\n",
"# Using `tf.function` makes `tf_simple_relu` a `PolymorphicFunction` that wraps\n",
"# `simple_relu`.\n",
"tf_simple_relu = tf.function(simple_relu)\n",
"\n",
"print(\"First branch, with graph:\", tf_simple_relu(tf.constant(1)).numpy())\n",
"print(\"Second branch, with graph:\", tf_simple_relu(tf.constant(-1)).numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hO4DBUNZBMwQ"
},
"source": [
"Though it is unlikely that you will need to view graphs directly, you can inspect the outputs to check the exact results. These are not easy to read, so no need to look too carefully!"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.060937Z",
"iopub.status.busy": "2024-08-15T01:24:03.060713Z",
"iopub.status.idle": "2024-08-15T01:24:03.065563Z",
"shell.execute_reply": "2024-08-15T01:24:03.064963Z"
},
"id": "lAKaat3w0gnn"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def tf__simple_relu(x):\n",
"    with ag__.FunctionScope('simple_relu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:\n",
"        do_return = False\n",
"        retval_ = ag__.UndefinedReturnValue()\n",
"\n",
"        def get_state():\n",
"            return (do_return, retval_)\n",
"\n",
"        def set_state(vars_):\n",
"            nonlocal do_return, retval_\n",
"            (do_return, retval_) = vars_\n",
"\n",
"        def if_body():\n",
"            nonlocal do_return, retval_\n",
"            try:\n",
"                do_return = True\n",
"                retval_ = ag__.ld(x)\n",
"            except:\n",
"                do_return = False\n",
"                raise\n",
"\n",
"        def else_body():\n",
"            nonlocal do_return, retval_\n",
"            try:\n",
"                do_return = True\n",
"                retval_ = 0\n",
"            except:\n",
"                do_return = False\n",
"                raise\n",
"        ag__.if_stmt(ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(x), 0), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)\n",
"        return fscope.ret(retval_, do_return)\n",
"\n"
]
}
],
"source": [
"# This is the graph-generating output of AutoGraph.\n",
"print(tf.autograph.to_code(simple_relu))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.068332Z",
"iopub.status.busy": "2024-08-15T01:24:03.068117Z",
"iopub.status.idle": "2024-08-15T01:24:03.072311Z",
"shell.execute_reply": "2024-08-15T01:24:03.071765Z"
},
"id": "8x6RAqza1UWf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"node {\n",
" name: \"x\"\n",
" op: \"Placeholder\"\n",
" attr {\n",
" key: \"_user_specified_name\"\n",
" value {\n",
" s: \"x\"\n",
" }\n",
" }\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" attr {\n",
" key: \"shape\"\n",
" value {\n",
" shape {\n",
" }\n",
" }\n",
" }\n",
"}\n",
"node {\n",
" name: \"Greater/y\"\n",
" op: \"Const\"\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" attr {\n",
" key: \"value\"\n",
" value {\n",
" tensor {\n",
" dtype: DT_INT32\n",
" tensor_shape {\n",
" }\n",
" int_val: 0\n",
" }\n",
" }\n",
" }\n",
"}\n",
"node {\n",
" name: \"Greater\"\n",
" op: \"Greater\"\n",
" input: \"x\"\n",
" input: \"Greater/y\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
"}\n",
"node {\n",
" name: \"cond\"\n",
" op: \"StatelessIf\"\n",
" input: \"Greater\"\n",
" input: \"x\"\n",
" attr {\n",
" key: \"Tcond\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
" attr {\n",
" key: \"Tin\"\n",
" value {\n",
" list {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"Tout\"\n",
" value {\n",
" list {\n",
" type: DT_BOOL\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"_lower_using_switch_merge\"\n",
" value {\n",
" b: true\n",
" }\n",
" }\n",
" attr {\n",
" key: \"_read_only_resource_inputs\"\n",
" value {\n",
" list {\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"else_branch\"\n",
" value {\n",
" func {\n",
" name: \"cond_false_31\"\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"output_shapes\"\n",
" value {\n",
" list {\n",
" shape {\n",
" }\n",
" shape {\n",
" }\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"then_branch\"\n",
" value {\n",
" func {\n",
" name: \"cond_true_30\"\n",
" }\n",
" }\n",
" }\n",
"}\n",
"node {\n",
" name: \"cond/Identity\"\n",
" op: \"Identity\"\n",
" input: \"cond\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
"}\n",
"node {\n",
" name: \"cond/Identity_1\"\n",
" op: \"Identity\"\n",
" input: \"cond:1\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
"}\n",
"node {\n",
" name: \"Identity\"\n",
" op: \"Identity\"\n",
" input: \"cond/Identity_1\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
"}\n",
"library {\n",
" function {\n",
" signature {\n",
" name: \"cond_false_31\"\n",
" input_arg {\n",
" name: \"cond_placeholder\"\n",
" type: DT_INT32\n",
" }\n",
" output_arg {\n",
" name: \"cond_identity\"\n",
" type: DT_BOOL\n",
" }\n",
" output_arg {\n",
" name: \"cond_identity_1\"\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Const\"\n",
" op: \"Const\"\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
" attr {\n",
" key: \"value\"\n",
" value {\n",
" tensor {\n",
" dtype: DT_BOOL\n",
" tensor_shape {\n",
" }\n",
" bool_val: true\n",
" }\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Const_1\"\n",
" op: \"Const\"\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
" attr {\n",
" key: \"value\"\n",
" value {\n",
" tensor {\n",
" dtype: DT_BOOL\n",
" tensor_shape {\n",
" }\n",
" bool_val: true\n",
" }\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Const_2\"\n",
" op: \"Const\"\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" attr {\n",
" key: \"value\"\n",
" value {\n",
" tensor {\n",
" dtype: DT_INT32\n",
" tensor_shape {\n",
" }\n",
" int_val: 0\n",
" }\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Const_3\"\n",
" op: \"Const\"\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
" attr {\n",
" key: \"value\"\n",
" value {\n",
" tensor {\n",
" dtype: DT_BOOL\n",
" tensor_shape {\n",
" }\n",
" bool_val: true\n",
" }\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Identity\"\n",
" op: \"Identity\"\n",
" input: \"cond/Const_3:output:0\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Const_4\"\n",
" op: \"Const\"\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" attr {\n",
" key: \"value\"\n",
" value {\n",
" tensor {\n",
" dtype: DT_INT32\n",
" tensor_shape {\n",
" }\n",
" int_val: 0\n",
" }\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Identity_1\"\n",
" op: \"Identity\"\n",
" input: \"cond/Const_4:output:0\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" }\n",
" ret {\n",
" key: \"cond_identity\"\n",
" value: \"cond/Identity:output:0\"\n",
" }\n",
" ret {\n",
" key: \"cond_identity_1\"\n",
" value: \"cond/Identity_1:output:0\"\n",
" }\n",
" attr {\n",
" key: \"_construction_context\"\n",
" value {\n",
" s: \"kEagerRuntime\"\n",
" }\n",
" }\n",
" arg_attr {\n",
" key: 0\n",
" value {\n",
" attr {\n",
" key: \"_output_shapes\"\n",
" value {\n",
" list {\n",
" shape {\n",
" }\n",
" }\n",
" }\n",
" }\n",
" }\n",
" }\n",
" }\n",
" function {\n",
" signature {\n",
" name: \"cond_true_30\"\n",
" input_arg {\n",
" name: \"cond_identity_1_x\"\n",
" type: DT_INT32\n",
" }\n",
" output_arg {\n",
" name: \"cond_identity\"\n",
" type: DT_BOOL\n",
" }\n",
" output_arg {\n",
" name: \"cond_identity_1\"\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Const\"\n",
" op: \"Const\"\n",
" attr {\n",
" key: \"dtype\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
" attr {\n",
" key: \"value\"\n",
" value {\n",
" tensor {\n",
" dtype: DT_BOOL\n",
" tensor_shape {\n",
" }\n",
" bool_val: true\n",
" }\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Identity\"\n",
" op: \"Identity\"\n",
" input: \"cond/Const:output:0\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_BOOL\n",
" }\n",
" }\n",
" }\n",
" node_def {\n",
" name: \"cond/Identity_1\"\n",
" op: \"Identity\"\n",
" input: \"cond_identity_1_x\"\n",
" attr {\n",
" key: \"T\"\n",
" value {\n",
" type: DT_INT32\n",
" }\n",
" }\n",
" }\n",
" ret {\n",
" key: \"cond_identity\"\n",
" value: \"cond/Identity:output:0\"\n",
" }\n",
" ret {\n",
" key: \"cond_identity_1\"\n",
" value: \"cond/Identity_1:output:0\"\n",
" }\n",
" attr {\n",
" key: \"_construction_context\"\n",
" value {\n",
" s: \"kEagerRuntime\"\n",
" }\n",
" }\n",
" arg_attr {\n",
" key: 0\n",
" value {\n",
" attr {\n",
" key: \"_output_shapes\"\n",
" value {\n",
" list {\n",
" shape {\n",
" }\n",
" }\n",
" }\n",
" }\n",
" attr {\n",
" key: \"_user_specified_name\"\n",
" value {\n",
" s: \"x\"\n",
" }\n",
" }\n",
" }\n",
" }\n",
" }\n",
"}\n",
"versions {\n",
" producer: 1882\n",
" min_consumer: 12\n",
"}\n",
"\n"
]
}
],
"source": [
"# This is the graph itself.\n",
"print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GZ4Ieg6tBE6l"
},
"source": [
"Most of the time, `tf.function` will work without special considerations. However, there are some caveats, and the [`tf.function` guide](./function.ipynb) can help here, as well as the [complete AutoGraph reference](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sIpc_jfjEZEg"
},
"source": [
"### Polymorphism: one `tf.function`, many graphs\n",
"\n",
"A `tf.Graph` is specialized to a specific type of inputs (for example, tensors with a specific [`dtype`](https://www.tensorflow.org/api_docs/python/tf/dtypes/DType) or objects with the same [`id()`](https://docs.python.org/3/library/functions.html#id)).\n",
"\n",
"Each time you invoke a `tf.function` with a set of arguments that can't be handled by any of its existing graphs (such as arguments with new `dtypes` or incompatible shapes), it creates a new `tf.Graph` specialized to those new arguments. The type specification of a `tf.Graph`'s inputs is represented by `tf.types.experimental.FunctionType`, also referred to as the **signature**. For more information regarding when a new `tf.Graph` is generated, how that can be controlled, and how `FunctionType` can be useful, go to the _Rules of tracing_ section of the [Better performance with `tf.function`](./function.ipynb) guide.\n",
"\n",
"The `tf.function` stores the `tf.Graph` corresponding to that signature in a `ConcreteFunction`. **A `ConcreteFunction` can be thought of as a wrapper around a `tf.Graph`.**\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.075759Z",
"iopub.status.busy": "2024-08-15T01:24:03.075108Z",
"iopub.status.idle": "2024-08-15T01:24:03.551868Z",
"shell.execute_reply": "2024-08-15T01:24:03.551164Z"
},
"id": "LOASwhbvIv_T"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(5.5, shape=(), dtype=float32)\n",
"tf.Tensor([1. 0.], shape=(2,), dtype=float32)\n",
"tf.Tensor([3. 0.], shape=(2,), dtype=float32)\n"
]
}
],
"source": [
"@tf.function\n",
"def my_relu(x):\n",
" return tf.maximum(0., x)\n",
"\n",
"# `my_relu` creates new graphs as it observes different input types.\n",
"print(my_relu(tf.constant(5.5)))\n",
"print(my_relu([1, -1]))\n",
"print(my_relu(tf.constant([3., -3.])))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1qRtw7R4KL9X"
},
"source": [
"If the `tf.function` has already been called with the same input types, it does not create a new `tf.Graph`."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.555192Z",
"iopub.status.busy": "2024-08-15T01:24:03.554936Z",
"iopub.status.idle": "2024-08-15T01:24:03.561082Z",
"shell.execute_reply": "2024-08-15T01:24:03.560508Z"
},
"id": "TjjbnL5OKNDP"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(0.0, shape=(), dtype=float32)\n",
"tf.Tensor([0. 1.], shape=(2,), dtype=float32)\n"
]
}
],
"source": [
"# These two calls do *not* create new graphs.\n",
"print(my_relu(tf.constant(-2.5))) # Input type matches `tf.constant(5.5)`.\n",
"print(my_relu(tf.constant([-1., 1.]))) # Input type matches `tf.constant([3., -3.])`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UohRmexhIpvQ"
},
"source": [
"Because it's backed by multiple graphs, a `tf.function` is (as the name \"PolymorphicFunction\" suggests) **polymorphic**. That enables it to support more input types than a single `tf.Graph` could represent, and to optimize each `tf.Graph` for better performance."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.564426Z",
"iopub.status.busy": "2024-08-15T01:24:03.563974Z",
"iopub.status.idle": "2024-08-15T01:24:03.567560Z",
"shell.execute_reply": "2024-08-15T01:24:03.566954Z"
},
"id": "dxzqebDYFmLy"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input Parameters:\n",
" x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)\n",
"Output Type:\n",
" TensorSpec(shape=(), dtype=tf.float32, name=None)\n",
"Captures:\n",
" None\n",
"\n",
"Input Parameters:\n",
" x (POSITIONAL_OR_KEYWORD): List[Literal[1], Literal[-1]]\n",
"Output Type:\n",
" TensorSpec(shape=(2,), dtype=tf.float32, name=None)\n",
"Captures:\n",
" None\n",
"\n",
"Input Parameters:\n",
" x (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(2,), dtype=tf.float32, name=None)\n",
"Output Type:\n",
" TensorSpec(shape=(2,), dtype=tf.float32, name=None)\n",
"Captures:\n",
" None\n"
]
}
],
"source": [
"# There are three `ConcreteFunction`s (one for each graph) in `my_relu`.\n",
"# The `ConcreteFunction` also knows the return type and shape!\n",
"print(my_relu.pretty_printed_concrete_signatures())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V11zkxU22XeD"
},
"source": [
"## Using `tf.function`\n",
"\n",
"So far, you've learned how to convert a Python function into a graph simply by using `tf.function` as a decorator or wrapper. But in practice, getting `tf.function` to work correctly can be tricky! In the following sections, you'll learn how you can make your code work as expected with `tf.function`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yp_n0B5-P0RU"
},
"source": [
"### Graph execution vs. eager execution\n",
"\n",
"The code in a `tf.function` can be executed both eagerly and as a graph. By default, `tf.function` executes its code as a graph:\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.570599Z",
"iopub.status.busy": "2024-08-15T01:24:03.570350Z",
"iopub.status.idle": "2024-08-15T01:24:03.573928Z",
"shell.execute_reply": "2024-08-15T01:24:03.573378Z"
},
"id": "_R0BOvBFxqVZ"
},
"outputs": [],
"source": [
"@tf.function\n",
"def get_MSE(y_true, y_pred):\n",
" sq_diff = tf.pow(y_true - y_pred, 2)\n",
" return tf.reduce_mean(sq_diff)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.576755Z",
"iopub.status.busy": "2024-08-15T01:24:03.576520Z",
"iopub.status.idle": "2024-08-15T01:24:03.583546Z",
"shell.execute_reply": "2024-08-15T01:24:03.582938Z"
},
"id": "zikMVPGhmDET"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([2 0 7 2 3], shape=(5,), dtype=int32)\n",
"tf.Tensor([9 9 1 1 5], shape=(5,), dtype=int32)\n"
]
}
],
"source": [
"y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)\n",
"y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)\n",
"print(y_true)\n",
"print(y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.586556Z",
"iopub.status.busy": "2024-08-15T01:24:03.586330Z",
"iopub.status.idle": "2024-08-15T01:24:03.634478Z",
"shell.execute_reply": "2024-08-15T01:24:03.633872Z"
},
"id": "07r08Dh158ft"
},
"outputs": [
{
"data": {
"text/plain": [
"
"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_MSE(y_true, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cyZNCRcQorGO"
},
"source": [
"To verify that your `tf.function`'s graph is doing the same computation as its equivalent Python function, you can make it execute eagerly with `tf.config.run_functions_eagerly(True)`. This is a switch that **turns off `tf.function`'s ability to create and run graphs**, instead of executing the code normally."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.637523Z",
"iopub.status.busy": "2024-08-15T01:24:03.637287Z",
"iopub.status.idle": "2024-08-15T01:24:03.640481Z",
"shell.execute_reply": "2024-08-15T01:24:03.639803Z"
},
"id": "lKoF6NjPoI8w"
},
"outputs": [],
"source": [
"tf.config.run_functions_eagerly(True)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.643121Z",
"iopub.status.busy": "2024-08-15T01:24:03.642878Z",
"iopub.status.idle": "2024-08-15T01:24:03.649311Z",
"shell.execute_reply": "2024-08-15T01:24:03.648700Z"
},
"id": "9ZLqTyn0oKeM"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_MSE(y_true, y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.652291Z",
"iopub.status.busy": "2024-08-15T01:24:03.652041Z",
"iopub.status.idle": "2024-08-15T01:24:03.654939Z",
"shell.execute_reply": "2024-08-15T01:24:03.654394Z"
},
"id": "cV7daQW9odn-"
},
"outputs": [],
"source": [
"# Don't forget to set it back when you are done.\n",
"tf.config.run_functions_eagerly(False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DKT3YBsqy0x4"
},
"source": [
"However, `tf.function` can behave differently under graph and eager execution. The Python [`print`](https://docs.python.org/3/library/functions.html#print) function is one example of how these two modes differ. Let's check out what happens when you insert a `print` statement to your function and call it repeatedly."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.658034Z",
"iopub.status.busy": "2024-08-15T01:24:03.657475Z",
"iopub.status.idle": "2024-08-15T01:24:03.661320Z",
"shell.execute_reply": "2024-08-15T01:24:03.660775Z"
},
"id": "BEJeVeBEoGjV"
},
"outputs": [],
"source": [
"@tf.function\n",
"def get_MSE(y_true, y_pred):\n",
" print(\"Calculating MSE!\")\n",
" sq_diff = tf.pow(y_true - y_pred, 2)\n",
" return tf.reduce_mean(sq_diff)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3sWTGwX3BzP1"
},
"source": [
"Observe what is printed:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.664596Z",
"iopub.status.busy": "2024-08-15T01:24:03.663991Z",
"iopub.status.idle": "2024-08-15T01:24:03.712257Z",
"shell.execute_reply": "2024-08-15T01:24:03.711645Z"
},
"id": "3rJIeBg72T9n"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Calculating MSE!\n"
]
}
],
"source": [
"error = get_MSE(y_true, y_pred)\n",
"error = get_MSE(y_true, y_pred)\n",
"error = get_MSE(y_true, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WLMXk1uxKQ44"
},
"source": [
"Is the output surprising? **`get_MSE` only printed once even though it was called *three* times.**\n",
"\n",
"To explain, the `print` statement is executed when `tf.function` runs the original code in order to create the graph in a process known as \"tracing\" (refer to the _Tracing_ section of the [`tf.function` guide](./function.ipynb). **Tracing captures the TensorFlow operations into a graph, and `print` is not captured in the graph.** That graph is then executed for all three calls **without ever running the Python code again**.\n",
"\n",
"As a sanity check, let's turn off graph execution to compare:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.715317Z",
"iopub.status.busy": "2024-08-15T01:24:03.715080Z",
"iopub.status.idle": "2024-08-15T01:24:03.717985Z",
"shell.execute_reply": "2024-08-15T01:24:03.717364Z"
},
"id": "oFSxRtcptYpe"
},
"outputs": [],
"source": [
"# Now, globally set everything to run eagerly to force eager execution.\n",
"tf.config.run_functions_eagerly(True)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.720825Z",
"iopub.status.busy": "2024-08-15T01:24:03.720592Z",
"iopub.status.idle": "2024-08-15T01:24:03.725124Z",
"shell.execute_reply": "2024-08-15T01:24:03.724511Z"
},
"id": "qYxrAtvzNgHR"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Calculating MSE!\n",
"Calculating MSE!\n",
"Calculating MSE!\n"
]
}
],
"source": [
"# Observe what is printed below.\n",
"error = get_MSE(y_true, y_pred)\n",
"error = get_MSE(y_true, y_pred)\n",
"error = get_MSE(y_true, y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.727840Z",
"iopub.status.busy": "2024-08-15T01:24:03.727615Z",
"iopub.status.idle": "2024-08-15T01:24:03.730739Z",
"shell.execute_reply": "2024-08-15T01:24:03.730099Z"
},
"id": "_Df6ynXcAaup"
},
"outputs": [],
"source": [
"tf.config.run_functions_eagerly(False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PUR7qC_bquCn"
},
"source": [
"`print` is a *Python side effect*, and there are other differences that you should be aware of when converting a function into a `tf.function`. Learn more in the _Limitations_ section of the [Better performance with `tf.function`](./function.ipynb) guide."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oTZJfV_tccVp"
},
"source": [
"Note: If you would like to print values in both eager and graph execution, use `tf.print` instead."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rMT_Xf5yKn9o"
},
"source": [
"### Non-strict execution\n",
"\n",
"\n",
"\n",
"Graph execution only executes the operations necessary to produce the observable effects, which include:\n",
"\n",
"- The return value of the function\n",
"- Documented well-known side-effects such as:\n",
" - Input/output operations, like `tf.print`\n",
" - Debugging operations, such as the assert functions in `tf.debugging`\n",
" - Mutations of `tf.Variable`\n",
"\n",
"This behavior is usually known as \"Non-strict execution\", and differs from eager execution, which steps through all of the program operations, needed or not.\n",
"\n",
"In particular, runtime error checking does not count as an observable effect. If an operation is skipped because it is unnecessary, it cannot raise any runtime errors.\n",
"\n",
"In the following example, the \"unnecessary\" operation `tf.gather` is skipped during graph execution, so the runtime error `InvalidArgumentError` is not raised as it would be in eager execution. Do not rely on an error being raised while executing a graph."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.734161Z",
"iopub.status.busy": "2024-08-15T01:24:03.733536Z",
"iopub.status.idle": "2024-08-15T01:24:03.741718Z",
"shell.execute_reply": "2024-08-15T01:24:03.741132Z"
},
"id": "OdN0nKlUwj7M"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([0.], shape=(1,), dtype=float32)\n"
]
}
],
"source": [
"def unused_return_eager(x):\n",
" # Get index 1 will fail when `len(x) == 1`\n",
" tf.gather(x, [1]) # unused \n",
" return x\n",
"\n",
"try:\n",
" print(unused_return_eager(tf.constant([0.0])))\n",
"except tf.errors.InvalidArgumentError as e:\n",
" # All operations are run during eager execution so an error is raised.\n",
" print(f'{type(e).__name__}: {e}')"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.744451Z",
"iopub.status.busy": "2024-08-15T01:24:03.744229Z",
"iopub.status.idle": "2024-08-15T01:24:03.782763Z",
"shell.execute_reply": "2024-08-15T01:24:03.782182Z"
},
"id": "d80Fob4MwhTs"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([0.], shape=(1,), dtype=float32)\n"
]
}
],
"source": [
"@tf.function\n",
"def unused_return_graph(x):\n",
" tf.gather(x, [1]) # unused\n",
" return x\n",
"\n",
"# Only needed operations are run during graph execution. The error is not raised.\n",
"print(unused_return_graph(tf.constant([0.0])))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "def6MupG9R0O"
},
"source": [
"### `tf.function` best practices\n",
"\n",
"It may take some time to get used to the behavior of `tf.function`. To get started quickly, first-time users should play around with decorating toy functions with `@tf.function` to get experience with going from eager to graph execution.\n",
"\n",
"*Designing for `tf.function`* may be your best bet for writing graph-compatible TensorFlow programs. Here are some tips:\n",
"- Toggle between eager and graph execution early and often with `tf.config.run_functions_eagerly` to pinpoint if/ when the two modes diverge.\n",
"- Create `tf.Variable`s\n",
"outside the Python function and modify them on the inside. The same goes for objects that use `tf.Variable`, like `tf.keras.layers`, `tf.keras.Model`s and `tf.keras.optimizers`.\n",
"- Avoid writing functions that depend on outer Python variables, excluding `tf.Variable`s and Keras objects. Learn more in _Depending on Python global and free variables_ of the [`tf.function` guide](./function.ipynb).\n",
"- Prefer to write functions which take tensors and other TensorFlow types as input. You can pass in other object types but be careful! Learn more in _Depending on Python objects_ of the [`tf.function` guide](./function.ipynb).\n",
"- Include as much computation as possible under a `tf.function` to maximize the performance gain. For example, decorate a whole training step or the entire training loop.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ViM3oBJVJrDx"
},
"source": [
"## Seeing the speed-up"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A6NHDp7vAKcJ"
},
"source": [
"`tf.function` usually improves the performance of your code, but the amount of speed-up depends on the kind of computation you run. Small computations can be dominated by the overhead of calling a graph. You can measure the difference in performance like so:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.786040Z",
"iopub.status.busy": "2024-08-15T01:24:03.785458Z",
"iopub.status.idle": "2024-08-15T01:24:03.789876Z",
"shell.execute_reply": "2024-08-15T01:24:03.789204Z"
},
"id": "jr7p1BBjauPK"
},
"outputs": [],
"source": [
"x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)\n",
"\n",
"def power(x, y):\n",
" result = tf.eye(10, dtype=tf.dtypes.int32)\n",
" for _ in range(y):\n",
" result = tf.matmul(x, result)\n",
" return result"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:03.792633Z",
"iopub.status.busy": "2024-08-15T01:24:03.792400Z",
"iopub.status.idle": "2024-08-15T01:24:07.899084Z",
"shell.execute_reply": "2024-08-15T01:24:07.898414Z"
},
"id": "ms2yJyAnUYxK"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Eager execution: 4.1027931490000356 seconds\n"
]
}
],
"source": [
"print(\"Eager execution:\", timeit.timeit(lambda: power(x, 100), number=1000), \"seconds\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:07.902306Z",
"iopub.status.busy": "2024-08-15T01:24:07.902046Z",
"iopub.status.idle": "2024-08-15T01:24:08.702266Z",
"shell.execute_reply": "2024-08-15T01:24:08.701616Z"
},
"id": "gUB2mTyRYRAe"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Graph execution: 0.7951284349999241 seconds\n"
]
}
],
"source": [
"power_as_graph = tf.function(power)\n",
"print(\"Graph execution:\", timeit.timeit(lambda: power_as_graph(x, 100), number=1000), \"seconds\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q1Pfo5YwwILi"
},
"source": [
"`tf.function` is commonly used to speed up training loops, and you can learn more about it in the _Speeding-up your training step with `tf.function`_ section of the [Writing a training loop from scratch](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch) with Keras guide.\n",
"\n",
"Note: You can also try `tf.function(jit_compile=True)` for a more significant performance boost, especially if your code is heavy on TensorFlow control flow and uses many small tensors. Learn more in the _Explicit compilation with `tf.function(jit_compile=True)`_ section of the [XLA overview](https://www.tensorflow.org/xla)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sm0bNFp8PX53"
},
"source": [
"### Performance and trade-offs\n",
"\n",
"Graphs can speed up your code, but the process of creating them has some overhead. For some functions, the creation of the graph takes more time than the execution of the graph. **This investment is usually quickly paid back with the performance boost of subsequent executions, but it's important to be aware that the first few steps of any large model training can be slower due to tracing.**\n",
"\n",
"No matter how large your model, you want to avoid tracing frequently. In the _Controlling retracing_ section, the [`tf.function` guide](./function.ipynb) discusses how to set input specifications and use tensor arguments to avoid retracing. If you find you are getting unusually poor performance, it's a good idea to check if you are retracing accidentally."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "F4InDaTjwmBA"
},
"source": [
"## When is a `tf.function` tracing?\n",
"\n",
"To figure out when your `tf.function` is tracing, add a `print` statement to its code. As a rule of thumb, `tf.function` will execute the `print` statement every time it traces."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:08.705856Z",
"iopub.status.busy": "2024-08-15T01:24:08.705297Z",
"iopub.status.idle": "2024-08-15T01:24:08.750136Z",
"shell.execute_reply": "2024-08-15T01:24:08.749463Z"
},
"id": "hXtwlbpofLgW"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing!\n",
"tf.Tensor(6, shape=(), dtype=int32)\n",
"tf.Tensor(11, shape=(), dtype=int32)\n"
]
}
],
"source": [
"@tf.function\n",
"def a_function_with_python_side_effect(x):\n",
" print(\"Tracing!\") # An eager-only side effect.\n",
" return x * x + tf.constant(2)\n",
"\n",
"# This is traced the first time.\n",
"print(a_function_with_python_side_effect(tf.constant(2)))\n",
"# The second time through, you won't see the side effect.\n",
"print(a_function_with_python_side_effect(tf.constant(3)))"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-15T01:24:08.753060Z",
"iopub.status.busy": "2024-08-15T01:24:08.752831Z",
"iopub.status.idle": "2024-08-15T01:24:08.783502Z",
"shell.execute_reply": "2024-08-15T01:24:08.782911Z"
},
"id": "inzSg8yzfNjl"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tracing!\n",
"tf.Tensor(6, shape=(), dtype=int32)\n",
"Tracing!\n",
"tf.Tensor(11, shape=(), dtype=int32)\n"
]
}
],
"source": [
"# This retraces each time the Python argument changes,\n",
"# as a Python argument could be an epoch count or other\n",
"# hyperparameter.\n",
"print(a_function_with_python_side_effect(2))\n",
"print(a_function_with_python_side_effect(3))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rtN8NW6AfKye"
},
"source": [
"New Python arguments always trigger the creation of a new graph, hence the extra tracing.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D1kbr5ocpS6R"
},
"source": [
"## Next steps\n",
"\n",
"You can learn more about `tf.function` on the API reference page and by following the [Better performance with `tf.function`](./function.ipynb) guide."
]
}
],
"metadata": {
"colab": {
"name": "intro_to_graphs.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 0
}