{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "1Pi_B2cvdBiW"
},
"source": [
"##### Copyright 2023 The TF-Agents Authors."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "form",
"execution": {
"iopub.execute_input": "2024-03-09T12:42:00.852981Z",
"iopub.status.busy": "2024-03-09T12:42:00.852404Z",
"iopub.status.idle": "2024-03-09T12:42:00.856284Z",
"shell.execute_reply": "2024-03-09T12:42:00.855681Z"
},
"id": "nQnmcm0oI1Q-"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GiI8CZYWcJ5n"
},
"source": [
"# Networks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "31uij8nIo5bG"
},
"source": [
"## Introduction\n",
"\n",
"In this colab we cover how to define custom networks for your agents. Networks define the models that agents train. TF-Agents provides several types of networks that are useful across agents:\n",
"\n",
"**Main Networks**\n",
"\n",
"* **QNetwork**: Used in Q-learning for environments with discrete actions, this network maps an observation to value estimates for each possible action.\n",
"* **CriticNetworks**: Also referred to as `ValueNetworks` in the literature, these learn to estimate some version of a value function, mapping a state to an estimate of the expected return of a policy. In other words, they estimate how good the agent's current state is.\n",
"* **ActorNetworks**: Learn a mapping from observations to actions. These networks are usually used by our policies to generate actions.\n",
"* **ActorDistributionNetworks**: Similar to `ActorNetworks`, but these generate a distribution that a policy can then sample to generate actions.\n",
"\n",
"**Helper Networks**\n",
"\n",
"* **EncodingNetwork**: Allows users to easily define a mapping of preprocessing layers to apply to a network's input.\n",
"* **DynamicUnrollLayer**: Automatically resets the network's state on episode boundaries as it is applied over a time sequence.\n",
"* **ProjectionNetwork**: Networks like `CategoricalProjectionNetwork` or `NormalProjectionNetwork` take inputs and generate the parameters required to build Categorical or Normal distributions.\n",
"\n",
"All examples in TF-Agents come with pre-configured networks. However, these networks are not set up to handle complex observations.\n",
"\n",
"If you have an environment that exposes more than one observation/action and you need to customize your networks, then this tutorial is for you!"
]
},
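{
"cell_type": "markdown",
"metadata": {
"id": "dynUnrollSk1"
},
"source": [
"The state-resetting behavior described for **DynamicUnrollLayer** can be illustrated with a small NumPy sketch. This is not the TF-Agents implementation: it assumes a toy recurrent cell `h = tanh(x @ W + h @ U)` and the TF-Agents convention that `StepType.FIRST == 0`, and the function `unroll_with_resets` is hypothetical. It zeroes the recurrent state whenever a new episode starts inside the sequence.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"FIRST = 0  # Assumed to match tf_agents.trajectories.StepType.FIRST.\n",
"\n",
"\n",
"def unroll_with_resets(inputs, step_types, w, u):\n",
"    # inputs: [T, input_dim]; step_types: [T]; w: [input_dim, units];\n",
"    # u: [units, units]. Returns the [T, units] sequence of hidden states.\n",
"    state = np.zeros(u.shape[0])\n",
"    outputs = []\n",
"    for x, step_type in zip(inputs, step_types):\n",
"        if step_type == FIRST:\n",
"            # A new episode begins: drop the state carried over from the last one.\n",
"            state = np.zeros_like(state)\n",
"        state = np.tanh(x @ w + state @ u)\n",
"        outputs.append(state)\n",
"    return np.stack(outputs)\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"w = rng.normal(size=(2, 4))\n",
"u = rng.normal(size=(4, 4))\n",
"inputs = rng.normal(size=(5, 2))\n",
"# Two episodes in one sequence: steps 0-2 and steps 3-4.\n",
"step_types = np.array([0, 1, 2, 0, 1])\n",
"outputs = unroll_with_resets(inputs, step_types, w, u)\n",
"```\n",
"\n",
"Because the state is reset at step 3, `outputs[3:]` is identical to unrolling the second episode on its own; without the reset, the end of the first episode would leak into the start of the second."
]
},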
{
"cell_type": "markdown",
"metadata": {
"id": "Wmk1GBT9cPqC"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uhGeL1Kpc3Pw"
},
"source": [
"If you haven't installed tf-agents yet, run:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-09T12:42:00.859767Z",
"iopub.status.busy": "2024-03-09T12:42:00.859207Z",
"iopub.status.idle": "2024-03-09T12:42:10.449819Z",
"shell.execute_reply": "2024-03-09T12:42:10.448979Z"
},
"id": "xsLTHlVdiZP3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting tf-agents\r\n",
" Using cached tf_agents-0.19.0-py3-none-any.whl.metadata (12 kB)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: absl-py>=0.6.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (1.4.0)\r\n",
"Collecting cloudpickle>=1.3 (from tf-agents)\r\n",
" Using cached cloudpickle-3.0.0-py3-none-any.whl.metadata (7.0 kB)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting gin-config>=0.4.0 (from tf-agents)\r\n",
" Using cached gin_config-0.5.0-py3-none-any.whl.metadata (2.9 kB)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting gym<=0.23.0,>=0.17.0 (from tf-agents)\r\n",
" Using cached gym-0.23.0-py3-none-any.whl\r\n",
"Requirement already satisfied: numpy>=1.19.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (1.26.4)\r\n",
"Requirement already satisfied: pillow in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (10.2.0)\r\n",
"Requirement already satisfied: six>=1.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (1.16.0)\r\n",
"Requirement already satisfied: protobuf>=3.11.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (3.20.3)\r\n",
"Requirement already satisfied: wrapt>=1.11.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-agents) (1.16.0)\r\n",
"Collecting typing-extensions==4.5.0 (from tf-agents)\r\n",
" Using cached typing_extensions-4.5.0-py3-none-any.whl.metadata (8.5 kB)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting pygame==2.1.3 (from tf-agents)\r\n",
" Using cached pygame-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.3 kB)\r\n",
"Collecting tensorflow-probability~=0.23.0 (from tf-agents)\r\n",
" Using cached tensorflow_probability-0.23.0-py2.py3-none-any.whl.metadata (13 kB)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting gym-notices>=0.0.4 (from gym<=0.23.0,>=0.17.0->tf-agents)\r\n",
" Using cached gym_notices-0.0.8-py3-none-any.whl.metadata (1.0 kB)\r\n",
"Requirement already satisfied: importlib-metadata>=4.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from gym<=0.23.0,>=0.17.0->tf-agents) (7.0.2)\r\n",
"Requirement already satisfied: decorator in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-probability~=0.23.0->tf-agents) (5.1.1)\r\n",
"Requirement already satisfied: gast>=0.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-probability~=0.23.0->tf-agents) (0.5.4)\r\n",
"Requirement already satisfied: dm-tree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-probability~=0.23.0->tf-agents) (0.1.8)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.10.0->gym<=0.23.0,>=0.17.0->tf-agents) (3.17.0)\r\n",
"Using cached tf_agents-0.19.0-py3-none-any.whl (1.4 MB)\r\n",
"Using cached pygame-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.7 MB)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cached typing_extensions-4.5.0-py3-none-any.whl (27 kB)\r\n",
"Using cached cloudpickle-3.0.0-py3-none-any.whl (20 kB)\r\n",
"Using cached gin_config-0.5.0-py3-none-any.whl (61 kB)\r\n",
"Using cached tensorflow_probability-0.23.0-py2.py3-none-any.whl (6.9 MB)\r\n",
"Using cached gym_notices-0.0.8-py3-none-any.whl (3.0 kB)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Installing collected packages: gym-notices, gin-config, typing-extensions, pygame, cloudpickle, tensorflow-probability, gym, tf-agents\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Attempting uninstall: typing-extensions\r\n",
" Found existing installation: typing_extensions 4.10.0\r\n",
" Uninstalling typing_extensions-4.10.0:\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Successfully uninstalled typing_extensions-4.10.0\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Successfully installed cloudpickle-3.0.0 gin-config-0.5.0 gym-0.23.0 gym-notices-0.0.8 pygame-2.1.3 tensorflow-probability-0.23.0 tf-agents-0.19.0 typing-extensions-4.5.0\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: tf-keras in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (2.16.0)\r\n",
"Requirement already satisfied: tensorflow<2.17,>=2.16 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-keras) (2.16.1)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: absl-py>=1.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (1.4.0)\r\n",
"Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (1.6.3)\r\n",
"Requirement already satisfied: flatbuffers>=23.5.26 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (24.3.7)\r\n",
"Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (0.5.4)\r\n",
"Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (0.2.0)\r\n",
"Requirement already satisfied: h5py>=3.10.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (3.10.0)\r\n",
"Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (16.0.6)\r\n",
"Requirement already satisfied: ml-dtypes~=0.3.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (0.3.2)\r\n",
"Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (3.3.0)\r\n",
"Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (23.2)\r\n",
"Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (3.20.3)\r\n",
"Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (2.31.0)\r\n",
"Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (69.1.1)\r\n",
"Requirement already satisfied: six>=1.12.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (1.16.0)\r\n",
"Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (2.4.0)\r\n",
"Requirement already satisfied: typing-extensions>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (4.5.0)\r\n",
"Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (1.16.0)\r\n",
"Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (1.62.1)\r\n",
"Requirement already satisfied: tensorboard<2.17,>=2.16 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (2.16.2)\r\n",
"Requirement already satisfied: keras>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (3.0.5)\r\n",
"Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (0.36.0)\r\n",
"Requirement already satisfied: numpy<2.0.0,>=1.23.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<2.17,>=2.16->tf-keras) (1.26.4)\r\n",
"Requirement already satisfied: wheel<1.0,>=0.23.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from astunparse>=1.6.0->tensorflow<2.17,>=2.16->tf-keras) (0.41.2)\r\n",
"Requirement already satisfied: rich in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras) (13.7.1)\r\n",
"Requirement already satisfied: namex in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras) (0.0.7)\r\n",
"Requirement already satisfied: dm-tree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras) (0.1.8)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow<2.17,>=2.16->tf-keras) (3.3.2)\r\n",
"Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow<2.17,>=2.16->tf-keras) (3.6)\r\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow<2.17,>=2.16->tf-keras) (2.2.1)\r\n",
"Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow<2.17,>=2.16->tf-keras) (2024.2.2)\r\n",
"Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras) (3.5.2)\r\n",
"Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras) (0.7.2)\r\n",
"Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras) (3.0.1)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras) (7.0.2)\r\n",
"Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras) (2.1.5)\r\n",
"Requirement already satisfied: markdown-it-py>=2.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras) (3.0.0)\r\n",
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras) (2.17.2)\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<2.17,>=2.16->tensorflow<2.17,>=2.16->tf-keras) (3.17.0)\r\n",
"Requirement already satisfied: mdurl~=0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown-it-py>=2.2.0->rich->keras>=3.0.0->tensorflow<2.17,>=2.16->tf-keras) (0.1.2)\r\n"
]
}
],
"source": [
"!pip install tf-agents\n",
"!pip install tf-keras"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-09T12:42:10.454582Z",
"iopub.status.busy": "2024-03-09T12:42:10.453853Z",
"iopub.status.idle": "2024-03-09T12:42:10.457625Z",
"shell.execute_reply": "2024-03-09T12:42:10.457055Z"
},
"id": "WPuD0bMEY9Iz"
},
"outputs": [],
"source": [
"import os\n",
"# Keep using keras-2 (tf-keras) rather than keras-3 (keras).\n",
"os.environ['TF_USE_LEGACY_KERAS'] = '1'"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-09T12:42:10.460778Z",
"iopub.status.busy": "2024-03-09T12:42:10.460448Z",
"iopub.status.idle": "2024-03-09T12:42:13.347895Z",
"shell.execute_reply": "2024-03-09T12:42:13.347116Z"
},
"id": "sdvop99JlYSM"
},
"outputs": [],
"source": [
"from __future__ import absolute_import\n",
"from __future__ import division\n",
"from __future__ import print_function\n",
"\n",
"import abc\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"\n",
"from tf_agents.environments import random_py_environment\n",
"from tf_agents.environments import tf_py_environment\n",
"from tf_agents.networks import encoding_network\n",
"from tf_agents.networks import network\n",
"from tf_agents.networks import utils\n",
"from tf_agents.specs import array_spec\n",
"from tf_agents.utils import common as common_utils\n",
"from tf_agents.utils import nest_utils"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ums84-YP_21F"
},
"source": [
"## Defining Networks\n",
"\n",
"### Network API\n",
"\n",
"In TF-Agents we subclass from Keras [Networks](https://github.com/tensorflow/agents/blob/master/tf_agents/networks/network.py). With it we can:\n",
"\n",
"* Simplify copy operations required when creating target networks.\n",
"* Perform automatic variable creation when calling `network.variables()`.\n",
"* Validate inputs against the network's `input_tensor_spec`.\n",
"\n",
"### EncodingNetwork\n",
"\n",
"As mentioned above, the `EncodingNetwork` allows us to easily define a mapping of preprocessing layers to apply to a network's input to generate some encoding.\n",
"\n",
"The `EncodingNetwork` is composed of the following mostly optional layers:\n",
"\n",
" * Preprocessing layers\n",
" * Preprocessing combiner\n",
" * Conv2D\n",
" * Flatten\n",
" * Dense\n",
"\n",
"What makes encoding networks special is that they apply input preprocessing, via the `preprocessing_layers` and `preprocessing_combiner` layers. Each of these can be specified as a nested structure. If the `preprocessing_layers` nest is shallower than `input_tensor_spec`, then each layer receives the corresponding subnest. For example, if:\n",
"\n",
"```\n",
"input_tensor_spec = ([TensorSpec(3)] * 2, [TensorSpec(3)] * 5)\n",
"preprocessing_layers = (Layer1(), Layer2())\n",
"```\n",
"\n",
"then preprocessing will call:\n",
"\n",
"```\n",
"preprocessed = [preprocessing_layers[0](observations[0]),\n",
" preprocessing_layers[1](observations[1])]\n",
"```\n",
"\n",
"However, if\n",
"\n",
"```\n",
"preprocessing_layers = ([Layer1() for _ in range(2)],\n",
" [Layer2() for _ in range(5)])\n",
"```\n",
"\n",
"then preprocessing will call:\n",
"\n",
"```python\n",
"preprocessed = [\n",
" layer(obs) for layer, obs in zip(flatten(preprocessing_layers),\n",
" flatten(observations))\n",
"]\n",
"```\n"
]
},
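{
"cell_type": "markdown",
"metadata": {
"id": "nestMapSk001"
},
"source": [
"The two nesting cases above can be reproduced in plain Python to see which layer receives which input. This is a hypothetical sketch, not the `EncodingNetwork` source: ordinary functions stand in for Keras layers, nests are limited to tuples and lists, and the helpers `flatten`, `flatten_up_to`, and `apply_preprocessing` are written here for illustration (TF-Agents itself relies on `tf.nest` for this).\n",
"\n",
"```python\n",
"def flatten(nest):\n",
"    # Flattens nested tuples/lists into a flat list of leaves.\n",
"    if isinstance(nest, (tuple, list)):\n",
"        return [leaf for item in nest for leaf in flatten(item)]\n",
"    return [nest]\n",
"\n",
"\n",
"def flatten_up_to(shallow, nest):\n",
"    # Flattens `nest` only as deep as `shallow`, so each leaf of `shallow`\n",
"    # is paired with a (possibly nested) subtree of `nest`.\n",
"    if isinstance(shallow, (tuple, list)):\n",
"        if len(shallow) != len(nest):\n",
"            raise ValueError('Structures do not match.')\n",
"        return [sub for s, n in zip(shallow, nest)\n",
"                for sub in flatten_up_to(s, n)]\n",
"    return [nest]\n",
"\n",
"\n",
"def apply_preprocessing(layers, observations):\n",
"    # Calls every leaf layer on the matching part of the observation nest.\n",
"    return [layer(obs) for layer, obs in\n",
"            zip(flatten(layers), flatten_up_to(layers, observations))]\n",
"\n",
"\n",
"# Mirrors input_tensor_spec = ([TensorSpec(3)] * 2, [TensorSpec(3)] * 5),\n",
"# with plain numbers standing in for tensors.\n",
"observations = ([1, 2], [3, 4, 5, 6, 7])\n",
"\n",
"# Shallow nest: one layer per top-level group, so each layer sees a list.\n",
"shallow_layers = (sum, max)\n",
"print(apply_preprocessing(shallow_layers, observations))  # [3, 7]\n",
"\n",
"# Full nest: one layer per leaf observation.\n",
"deep_layers = ([lambda x: x * 10] * 2, [lambda x: -x] * 5)\n",
"print(apply_preprocessing(deep_layers, observations))  # [10, 20, -3, -4, -5, -6, -7]\n",
"```\n",
"\n",
"Giving a sublist the wrong number of layers, for example `([abs] * 3, [abs] * 5)`, makes `flatten_up_to` raise a `ValueError`, loosely analogous to a layer nest that is incompatible with `input_tensor_spec`."
]
},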
{
"cell_type": "markdown",
"metadata": {
"id": "RP3H1bw0ykro"
},
"source": [
"### Custom Networks\n",
"\n",
"To create your own network, you only need to override the `__init__` and `call` methods. Let's use what we learned about `EncodingNetwork` to create an `ActorNetwork` that takes observations containing both an image and a vector.\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-09T12:42:13.352703Z",
"iopub.status.busy": "2024-03-09T12:42:13.351965Z",
"iopub.status.idle": "2024-03-09T12:42:13.362284Z",
"shell.execute_reply": "2024-03-09T12:42:13.361661Z"
},
"id": "Zp0TjAJhYo4s"
},
"outputs": [],
"source": [
"class ActorNetwork(network.Network):\n",
"\n",
" def __init__(self,\n",
" observation_spec,\n",
" action_spec,\n",
" preprocessing_layers=None,\n",
" preprocessing_combiner=None,\n",
" conv_layer_params=None,\n",
" fc_layer_params=(75, 40),\n",
" dropout_layer_params=None,\n",
" activation_fn=tf.keras.activations.relu,\n",
" enable_last_layer_zero_initializer=False,\n",
" name='ActorNetwork'):\n",
" super(ActorNetwork, self).__init__(\n",
" input_tensor_spec=observation_spec, state_spec=(), name=name)\n",
"\n",
" # For simplicity we will only support a single action float output.\n",
" self._action_spec = action_spec\n",
" flat_action_spec = tf.nest.flatten(action_spec)\n",
" if len(flat_action_spec) > 1:\n",
" raise ValueError('Only a single action is supported by this network')\n",
" self._single_action_spec = flat_action_spec[0]\n",
" if self._single_action_spec.dtype not in [tf.float32, tf.float64]:\n",
" raise ValueError('Only float actions are supported by this network.')\n",
"\n",
" kernel_initializer = tf.keras.initializers.VarianceScaling(\n",
" scale=1. / 3., mode='fan_in', distribution='uniform')\n",
" self._encoder = encoding_network.EncodingNetwork(\n",
" observation_spec,\n",
" preprocessing_layers=preprocessing_layers,\n",
" preprocessing_combiner=preprocessing_combiner,\n",
" conv_layer_params=conv_layer_params,\n",
" fc_layer_params=fc_layer_params,\n",
" dropout_layer_params=dropout_layer_params,\n",
" activation_fn=activation_fn,\n",
" kernel_initializer=kernel_initializer,\n",
" batch_squash=False)\n",
"\n",
" initializer = tf.keras.initializers.RandomUniform(\n",
" minval=-0.003, maxval=0.003)\n",
"\n",
" self._action_projection_layer = tf.keras.layers.Dense(\n",
" flat_action_spec[0].shape.num_elements(),\n",
" activation=tf.keras.activations.tanh,\n",
" kernel_initializer=initializer,\n",
" name='action')\n",
"\n",
" def call(self, observations, step_type=(), network_state=()):\n",
" outer_rank = nest_utils.get_outer_rank(observations, self.input_tensor_spec)\n",
"    # We use batch_squash here in case the observations have a time sequence\n",
"    # component.\n",
" batch_squash = utils.BatchSquash(outer_rank)\n",
" observations = tf.nest.map_structure(batch_squash.flatten, observations)\n",
"\n",
" state, network_state = self._encoder(\n",
" observations, step_type=step_type, network_state=network_state)\n",
" actions = self._action_projection_layer(state)\n",
" actions = common_utils.scale_to_spec(actions, self._single_action_spec)\n",
" actions = batch_squash.unflatten(actions)\n",
" return tf.nest.pack_sequence_as(self._action_spec, [actions]), network_state"
]
},
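{
"cell_type": "markdown",
"metadata": {
"id": "callPathSk01"
},
"source": [
"The `call` method above squashes any outer batch/time dimensions, projects the encoding through a `tanh` layer, and rescales the result into the action bounds. The NumPy sketch below mimics those steps under stated assumptions: a random matrix stands in for the encoder plus the `action` Dense layer, and `scale_to_spec` is assumed to be the affine map from `[-1, 1]` onto `[minimum, maximum]` (midpoint plus half-range times the input). The functions `squash`, `unsquash`, and `scale_to_spec` here are hypothetical stand-ins, not the TF-Agents functions.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def squash(x, outer_rank):\n",
"    # Like BatchSquash.flatten: merge the outer dims (e.g. [B, T]) into one.\n",
"    return x.reshape((-1,) + x.shape[outer_rank:]), x.shape[:outer_rank]\n",
"\n",
"\n",
"def unsquash(x, outer_shape):\n",
"    # Like BatchSquash.unflatten: restore the original outer dims.\n",
"    return x.reshape(tuple(outer_shape) + x.shape[1:])\n",
"\n",
"\n",
"def scale_to_spec(x, minimum, maximum):\n",
"    # Assumed affine map from [-1, 1] onto [minimum, maximum].\n",
"    means = (maximum + minimum) / 2.0\n",
"    magnitudes = (maximum - minimum) / 2.0\n",
"    return means + magnitudes * x\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"batch, steps, features, action_dim = 2, 4, 6, 3\n",
"observations = rng.normal(size=(batch, steps, features))\n",
"# Stand-in for the encoder plus the 'action' Dense layer's weights.\n",
"projection = rng.normal(size=(features, action_dim))\n",
"\n",
"flat, outer_shape = squash(observations, outer_rank=2)  # [8, 6]\n",
"actions = np.tanh(flat @ projection)  # tanh keeps values in [-1, 1]\n",
"actions = scale_to_spec(actions, minimum=0.0, maximum=10.0)\n",
"actions = unsquash(actions, outer_shape)  # [2, 4, 3]\n",
"```\n",
"\n",
"Squashing lets layers that expect a single batch dimension, such as `Conv2D`, also accept `[batch, time, ...]` inputs, and the `tanh` activation is what keeps the rescaled actions inside the spec bounds."
]
},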
{
"cell_type": "markdown",
"metadata": {
"id": "Fm-MbMMLYiZj"
},
"source": [
"Let's create a `RandomPyEnvironment` to generate structured observations and validate our implementation."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-09T12:42:13.365785Z",
"iopub.status.busy": "2024-03-09T12:42:13.365280Z",
"iopub.status.idle": "2024-03-09T12:42:13.378607Z",
"shell.execute_reply": "2024-03-09T12:42:13.377981Z"
},
"id": "E2XoNuuD66s5"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tf_agents/specs/array_spec.py:352: RuntimeWarning: invalid value encountered in cast\n",
" self._minimum[self._minimum == -np.inf] = low\n",
"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tf_agents/specs/array_spec.py:353: RuntimeWarning: invalid value encountered in cast\n",
" self._minimum[self._minimum == np.inf] = high\n",
"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tf_agents/specs/array_spec.py:355: RuntimeWarning: invalid value encountered in cast\n",
" self._maximum[self._maximum == -np.inf] = low\n",
"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tf_agents/specs/array_spec.py:356: RuntimeWarning: invalid value encountered in cast\n",
" self._maximum[self._maximum == np.inf] = high\n"
]
}
],
"source": [
"action_spec = array_spec.BoundedArraySpec((3,), np.float32, minimum=0, maximum=10)\n",
"observation_spec = {\n",
" 'image': array_spec.BoundedArraySpec((16, 16, 3), np.float32, minimum=0,\n",
" maximum=255),\n",
" 'vector': array_spec.BoundedArraySpec((5,), np.float32, minimum=-100,\n",
" maximum=100)}\n",
"\n",
"random_env = random_py_environment.RandomPyEnvironment(observation_spec, action_spec=action_spec)\n",
"\n",
"# Convert the environment to a TFEnv to generate tensors.\n",
"tf_env = tf_py_environment.TFPyEnvironment(random_env)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LM3uDTD7TNVx"
},
"source": [
"Since we've defined the observations to be a dict, we need to create preprocessing layers to handle each entry."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-09T12:42:13.381891Z",
"iopub.status.busy": "2024-03-09T12:42:13.381426Z",
"iopub.status.idle": "2024-03-09T12:42:16.233111Z",
"shell.execute_reply": "2024-03-09T12:42:16.232360Z"
},
"id": "r9U6JVevTAJw"
},
"outputs": [],
"source": [
"preprocessing_layers = {\n",
" 'image': tf.keras.models.Sequential([tf.keras.layers.Conv2D(8, 4),\n",
" tf.keras.layers.Flatten()]),\n",
" 'vector': tf.keras.layers.Dense(5)\n",
" }\n",
"preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)\n",
"actor = ActorNetwork(tf_env.observation_spec(), \n",
" tf_env.action_spec(),\n",
" preprocessing_layers=preprocessing_layers,\n",
" preprocessing_combiner=preprocessing_combiner)"
]
},
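{
"cell_type": "markdown",
"metadata": {
"id": "shapeMathSk1"
},
"source": [
"It can help to work out the shapes that flow through this actor. The sketch below is plain arithmetic under stated assumptions: `Conv2D(8, 4)` uses its Keras defaults (stride 1, `'valid'` padding), `TFPyEnvironment` adds an outer batch dimension of 1, and the `EncodingNetwork` applies the default `fc_layer_params=(75, 40)` after concatenation. The helper `conv2d_valid_output` is hypothetical.\n",
"\n",
"```python\n",
"def conv2d_valid_output(size, kernel, stride=1):\n",
"    # Spatial output size of a convolution with 'valid' padding.\n",
"    return (size - kernel) // stride + 1\n",
"\n",
"\n",
"image_h, image_w, image_c = 16, 16, 3\n",
"filters, kernel = 8, 4\n",
"conv_h = conv2d_valid_output(image_h, kernel)  # 13\n",
"conv_w = conv2d_valid_output(image_w, kernel)  # 13\n",
"image_features = conv_h * conv_w * filters  # Flatten: 13 * 13 * 8 = 1352\n",
"vector_features = 5  # Dense(5)\n",
"combined = image_features + vector_features  # Concatenate(axis=-1): 1357\n",
"\n",
"fc_layer_params = (75, 40)\n",
"encoding_size = fc_layer_params[-1]  # 40\n",
"batch_size, action_dim = 1, 3  # TFPyEnvironment batch, action spec shape (3,)\n",
"action_shape = (batch_size, action_dim)\n",
"print(combined, encoding_size, action_shape)  # 1357 40 (1, 3)\n",
"```\n",
"\n",
"So the `action` Dense layer receives a `[1, 40]` encoding and emits `[1, 3]`, which `scale_to_spec` then maps into the `[0, 10]` action bounds."
]
},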
{
"cell_type": "markdown",
"metadata": {
"id": "mM9qedlwc41U"
},
"source": [
"Now that we have the actor network, we can process observations from the environment."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2024-03-09T12:42:16.237409Z",
"iopub.status.busy": "2024-03-09T12:42:16.237119Z",
"iopub.status.idle": "2024-03-09T12:42:17.228799Z",
"shell.execute_reply": "2024-03-09T12:42:17.228142Z"
},
"id": "JOkkeu7vXoei"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tf_keras/src/initializers/initializers.py:121: UserWarning: The initializer VarianceScaling is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"(<tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>,\n",
" ())"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"time_step = tf_env.reset()\n",
"actor(time_step.observation, time_step.step_type)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ALGxaQLWc9GI"
},
"source": [
"This same strategy can be used to customize any of the main networks used by the agents. You can define whatever preprocessing you need and connect it to the rest of the network. As you define your own custom networks, make sure the network's output layer definitions match what the agent expects (for an actor network like this one, the action spec)."
]
}
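,
{
"cell_type": "markdown",
"metadata": {
"id": "specCheckSk1"
},
"source": [
"One way to act on that advice is to check a network's output against the action spec before handing it to an agent. The function below is a hypothetical NumPy helper, not part of TF-Agents; it assumes the actions arrive as an array with a leading batch dimension and that the spec is described by a shape, a dtype, and scalar bounds, as in `BoundedArraySpec((3,), np.float32, minimum=0, maximum=10)`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def check_actions_match_spec(actions, shape, dtype, minimum, maximum):\n",
"    # Hypothetical helper: returns a list of problems (empty if none).\n",
"    problems = []\n",
"    if actions.shape[1:] != tuple(shape):\n",
"        problems.append(f'shape {actions.shape[1:]} != spec shape {tuple(shape)}')\n",
"    if actions.dtype != dtype:\n",
"        problems.append(f'dtype {actions.dtype} != spec dtype {np.dtype(dtype)}')\n",
"    if np.any(actions < minimum) or np.any(actions > maximum):\n",
"        problems.append(f'values outside [{minimum}, {maximum}]')\n",
"    return problems\n",
"\n",
"\n",
"good = np.full((1, 3), 5.0, dtype=np.float32)\n",
"bad = np.full((1, 2), 12.0, dtype=np.float64)\n",
"print(check_actions_match_spec(good, (3,), np.float32, 0, 10))  # []\n",
"print(check_actions_match_spec(bad, (3,), np.float32, 0, 10))\n",
"```\n",
"\n",
"The second call reports a shape, a dtype, and a bounds mismatch, which are the kinds of output-layer errors that are easy to introduce when swapping in custom preprocessing."
]
}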
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "TF-Agents Networks Tutorial.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 0
}