{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Tce3stUlHN0L"
},
"source": [
"##### Copyright 2021 The TensorFlow IO Authors."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "form",
"id": "tuOe1ymfHZPu"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qFdPvlXBOdUN"
},
"source": [
"# Tensorflow datasets from MongoDB collections "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xHxb-dlhMIzW"
},
"source": [
"## Overview\n",
"\n",
"This tutorial focuses on preparing `tf.data.Dataset`s by reading data from mongoDB collections and using it for training a `tf.keras` model.\n",
"\n",
"**NOTE:** A basic understanding of [mongodb storage](https://docs.mongodb.com/guides/) will help you in following the tutorial with ease."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MUXex9ctTuDB"
},
"source": [
"## Setup packages\n",
"\n",
"This tutorial uses `pymongo` as a helper package to create a new mongodb database and collection to store the data. \n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "upgCc3gXybsA"
},
"source": [
"### Install the required tensorflow-io and mongodb (helper) packages"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "48B9eAMMhAgw"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Ignoring invalid distribution -eras (/usr/local/lib/python3.7/dist-packages)\u001b[0m\n",
"\u001b[33mWARNING: Ignoring invalid distribution -eras (/usr/local/lib/python3.7/dist-packages)\u001b[0m\n",
"\u001b[33mWARNING: Ignoring invalid distribution -eras (/usr/local/lib/python3.7/dist-packages)\u001b[0m\n",
"\u001b[33m WARNING: Ignoring invalid distribution -eras (/usr/local/lib/python3.7/dist-packages)\u001b[0m\n",
"\u001b[33mWARNING: Ignoring invalid distribution -eras (/usr/local/lib/python3.7/dist-packages)\u001b[0m\n",
"\u001b[33mWARNING: Ignoring invalid distribution -eras (/usr/local/lib/python3.7/dist-packages)\u001b[0m\n",
"\u001b[33mWARNING: Ignoring invalid distribution -eras (/usr/local/lib/python3.7/dist-packages)\u001b[0m\n",
"\u001b[33mWARNING: Ignoring invalid distribution -eras (/usr/local/lib/python3.7/dist-packages)\u001b[0m\n",
"\u001b[33mWARNING: Ignoring invalid distribution -eras (/usr/local/lib/python3.7/dist-packages)\u001b[0m\n",
"\u001b[33mWARNING: Ignoring invalid distribution -eras (/usr/local/lib/python3.7/dist-packages)\u001b[0m\n",
"\u001b[33mWARNING: Ignoring invalid distribution -eras (/usr/local/lib/python3.7/dist-packages)\u001b[0m\n",
"\u001b[33mWARNING: Ignoring invalid distribution -eras (/usr/local/lib/python3.7/dist-packages)\u001b[0m\n"
]
}
],
"source": [
"!pip install -q tensorflow-io\n",
"!pip install -q pymongo"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gjrZNJQRJP-U"
},
"source": [
"### Import packages"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "m6KXZuTBWgRm"
},
"outputs": [],
"source": [
"import os\n",
"import time\n",
"from pprint import pprint\n",
"from sklearn.model_selection import train_test_split\n",
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"from tensorflow.keras import layers\n",
"from tensorflow.keras.layers.experimental import preprocessing\n",
"import tensorflow_io as tfio\n",
"from pymongo import MongoClient"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eCgO11GTJaTj"
},
"source": [
"### Validate tf and tfio imports"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "dX74RKfZ_TdF"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensorflow-io version: 0.20.0\n",
"tensorflow version: 2.6.0\n"
]
}
],
"source": [
"print(\"tensorflow-io version: {}\".format(tfio.__version__))\n",
"print(\"tensorflow version: {}\".format(tf.__version__))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yZmI7l_GykcW"
},
"source": [
"## Download and setup the MongoDB instance\n",
"\n",
"For demo purposes, the open-source version of mongodb is used.\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "YUj0878jPyz7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" * Starting database mongodb\n",
" ...done.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"WARNING: apt does not have a stable CLI interface. Use with caution in scripts.\n",
"\n",
"debconf: unable to initialize frontend: Dialog\n",
"debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 76, <> line 8.)\n",
"debconf: falling back to frontend: Readline\n",
"debconf: unable to initialize frontend: Readline\n",
"debconf: (This frontend requires a controlling tty.)\n",
"debconf: falling back to frontend: Teletype\n",
"dpkg-preconfigure: unable to re-open stdin: \n"
]
}
],
"source": [
"%%bash\n",
"\n",
"sudo apt install -y mongodb >log\n",
"service mongodb start"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "XyUa9r6MgWtW"
},
"outputs": [],
"source": [
"# Sleep for few seconds to let the instance start.\n",
"time.sleep(5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f6qxCdypE1DD"
},
"source": [
"Once the instance has been started, grep for `mongo` in the processes list to confirm the availability."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "48LqMJ1BEHm5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mongodb 580 1 13 17:38 ? 00:00:00 /usr/bin/mongod --config /etc/mongodb.conf\n",
"root 612 610 0 17:38 ? 00:00:00 grep mongo\n"
]
}
],
"source": [
"%%bash\n",
"\n",
"ps -ef | grep mongo"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wBuRpiyf_kNS"
},
"source": [
"query the base endpoint to retrieve information about the cluster."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "m8EH1-N-idTn"
},
"outputs": [
{
"data": {
"text/plain": [
"['admin', 'local']"
]
},
"execution_count": 8,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"client = MongoClient()\n",
"client.list_database_names() # ['admin', 'local']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4CfKVmCvwcL7"
},
"source": [
"### Explore the dataset\n",
"\n",
"For the purpose of this tutorial, lets download the [PetFinder](https://www.kaggle.com/c/petfinder-adoption-prediction) dataset and feed the data into mongodb manually. The goal of this classification problem is predict if the pet will be adopted or not.\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "XkXyocIdKRSB"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/petfinder-mini.zip\n",
"1671168/1668792 [==============================] - 0s 0us/step\n",
"1679360/1668792 [==============================] - 0s 0us/step\n"
]
}
],
"source": [
"dataset_url = 'https://storage.googleapis.com/download.tensorflow.org/data/petfinder-mini.zip'\n",
"csv_file = 'datasets/petfinder-mini/petfinder-mini.csv'\n",
"tf.keras.utils.get_file('petfinder_mini.zip', dataset_url,\n",
" extract=True, cache_dir='.')\n",
"pf_df = pd.read_csv(csv_file)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "nC-yt_c9u0sH"
},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Type | \n",
" Age | \n",
" Breed1 | \n",
" Gender | \n",
" Color1 | \n",
" Color2 | \n",
" MaturitySize | \n",
" FurLength | \n",
" Vaccinated | \n",
" Sterilized | \n",
" Health | \n",
" Fee | \n",
" Description | \n",
" PhotoAmt | \n",
" AdoptionSpeed | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Cat | \n",
" 3 | \n",
" Tabby | \n",
" Male | \n",
" Black | \n",
" White | \n",
" Small | \n",
" Short | \n",
" No | \n",
" No | \n",
" Healthy | \n",
" 100 | \n",
" Nibble is a 3+ month old ball of cuteness. He ... | \n",
" 1 | \n",
" 2 | \n",
"
\n",
" \n",
" 1 | \n",
" Cat | \n",
" 1 | \n",
" Domestic Medium Hair | \n",
" Male | \n",
" Black | \n",
" Brown | \n",
" Medium | \n",
" Medium | \n",
" Not Sure | \n",
" Not Sure | \n",
" Healthy | \n",
" 0 | \n",
" I just found it alone yesterday near my apartm... | \n",
" 2 | \n",
" 0 | \n",
"
\n",
" \n",
" 2 | \n",
" Dog | \n",
" 1 | \n",
" Mixed Breed | \n",
" Male | \n",
" Brown | \n",
" White | \n",
" Medium | \n",
" Medium | \n",
" Yes | \n",
" No | \n",
" Healthy | \n",
" 0 | \n",
" Their pregnant mother was dumped by her irresp... | \n",
" 7 | \n",
" 3 | \n",
"
\n",
" \n",
" 3 | \n",
" Dog | \n",
" 4 | \n",
" Mixed Breed | \n",
" Female | \n",
" Black | \n",
" Brown | \n",
" Medium | \n",
" Short | \n",
" Yes | \n",
" No | \n",
" Healthy | \n",
" 150 | \n",
" Good guard dog, very alert, active, obedience ... | \n",
" 8 | \n",
" 2 | \n",
"
\n",
" \n",
" 4 | \n",
" Dog | \n",
" 1 | \n",
" Mixed Breed | \n",
" Male | \n",
" Black | \n",
" No Color | \n",
" Medium | \n",
" Short | \n",
" No | \n",
" No | \n",
" Healthy | \n",
" 0 | \n",
" This handsome yet cute boy is up for adoption.... | \n",
" 3 | \n",
" 2 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Type Age ... PhotoAmt AdoptionSpeed\n",
"0 Cat 3 ... 1 2\n",
"1 Cat 1 ... 2 0\n",
"2 Dog 1 ... 7 3\n",
"3 Dog 4 ... 8 2\n",
"4 Dog 1 ... 3 2\n",
"\n",
"[5 rows x 15 columns]"
]
},
"execution_count": 10,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"pf_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FTFL8nmnGVOc"
},
"source": [
"For the purpose of the tutorial, modifications are made to the label column.\n",
"0 will indicate the pet was not adopted, and 1 will indicate that it was.\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "c6Cg22bU0-na"
},
"outputs": [],
"source": [
"# In the original dataset \"4\" indicates the pet was not adopted.\n",
"pf_df['target'] = np.where(pf_df['AdoptionSpeed']==4, 0, 1)\n",
"\n",
"# Drop un-used columns.\n",
"pf_df = pf_df.drop(columns=['AdoptionSpeed', 'Description'])\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "klnNOM5oGtH1"
},
"outputs": [
{
"data": {
"text/plain": [
"(11537, 14)"
]
},
"execution_count": 12,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# Number of datapoints and columns\n",
"len(pf_df), len(pf_df.columns)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tF5K9xtmlT2P"
},
"source": [
"### Split the dataset\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "n-ku_X0Wld59"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of training samples: 8075\n",
"Number of testing sample: 3462\n"
]
}
],
"source": [
"train_df, test_df = train_test_split(pf_df, test_size=0.3, shuffle=True)\n",
"print(\"Number of training samples: \",len(train_df))\n",
"print(\"Number of testing sample: \",len(test_df))\n"
]
},
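{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, inspect the class balance of the training split (a quick check; the exact counts depend on the random shuffle):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Distribution of the binary label in the training split.\n",
"print(train_df['target'].value_counts())"
]
},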
{
"cell_type": "markdown",
"metadata": {
"id": "wwP5U4GqmhoL"
},
"source": [
"### Store the train and test data in mongo collections"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "edzds_qwk0Id"
},
"outputs": [],
"source": [
"URI = \"mongodb://localhost:27017\"\n",
"DATABASE = \"tfiodb\"\n",
"TRAIN_COLLECTION = \"train\"\n",
"TEST_COLLECTION = \"test\""
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "x6eT1wHykRKq"
},
"outputs": [],
"source": [
"db = client[DATABASE]\n",
"if \"train\" not in db.list_collection_names():\n",
" db.create_collection(TRAIN_COLLECTION)\n",
"if \"test\" not in db.list_collection_names():\n",
" db.create_collection(TEST_COLLECTION)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "YhwFImSqncLE"
},
"outputs": [],
"source": [
"def store_records(collection, records):\n",
" writer = tfio.experimental.mongodb.MongoDBWriter(\n",
" uri=URI, database=DATABASE, collection=collection\n",
" )\n",
" for record in records:\n",
" writer.write(record)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "4wBiwCRBNGAu"
},
"outputs": [],
"source": [
"store_records(collection=\"train\", records=train_df.to_dict(\"records\"))\n",
"time.sleep(2)\n",
"store_records(collection=\"test\", records=test_df.to_dict(\"records\"))"
]
},
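{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check (a minimal sketch using the `pymongo` client created earlier), count the documents in each collection to confirm that the writes landed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The totals should match len(train_df) and len(test_df).\n",
"print(\"train records:\", db[TRAIN_COLLECTION].count_documents({}))\n",
"print(\"test records:\", db[TEST_COLLECTION].count_documents({}))"
]
},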
{
"cell_type": "markdown",
"metadata": {
"id": "2mOrfOYrHpQj"
},
"source": [
"## Prepare tfio datasets\n",
"\n",
"Once the data is available in the cluster, the `mongodb.MongoDBIODataset` class is utilized for this purpose. The class inherits from `tf.data.Dataset` and thus exposes all the useful functionalities of `tf.data.Dataset` out of the box.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "58q52py93jEf"
},
"source": [
"### Training dataset\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"id": "HHOcitbW2_d1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Connection successful: mongodb://localhost:27017\n",
"WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow/python/data/experimental/ops/counter.py:66: scan (from tensorflow.python.data.experimental.ops.scan_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.data.Dataset.scan(...) instead\n",
"WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow_io/python/experimental/mongodb_dataset_ops.py:114: take_while (from tensorflow.python.data.experimental.ops.take_while_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.data.Dataset.take_while(...)\n"
]
},
{
"data": {
"text/plain": [
"
"
]
},
"execution_count": 18,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"train_ds = tfio.experimental.mongodb.MongoDBIODataset(\n",
" uri=URI, database=DATABASE, collection=TRAIN_COLLECTION\n",
" )\n",
"\n",
"train_ds"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IdwGj48SqxXY"
},
"source": [
"Each item in `train_ds` is a string which needs to be decoded into a json. To do so, you can select only a subset of the columns by specifying the `TensorSpec`"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"id": "fZXMXXbJrHtk"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'Fee': TensorSpec(shape=(), dtype=tf.int32, name='Fee'),\n",
" 'PhotoAmt': TensorSpec(shape=(), dtype=tf.int32, name='PhotoAmt'),\n",
" 'target': TensorSpec(shape=(), dtype=tf.int64, name='target')}\n"
]
}
],
"source": [
"# Numeric features.\n",
"numerical_cols = ['PhotoAmt', 'Fee'] \n",
"\n",
"SPECS = {\n",
" \"target\": tf.TensorSpec(tf.TensorShape([]), tf.int64, name=\"target\"),\n",
"}\n",
"for col in numerical_cols:\n",
" SPECS[col] = tf.TensorSpec(tf.TensorShape([]), tf.int32, name=col)\n",
"pprint(SPECS)"
]
},
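{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how the decoding works on a single record, you can run `decode_json` on a hand-written JSON string (the values below are made up for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Decode a sample JSON document into a dict of tensors matching SPECS.\n",
"sample_json = tf.constant('{\"target\": 1, \"PhotoAmt\": 2, \"Fee\": 0}')\n",
"pprint(tfio.experimental.serialization.decode_json(sample_json, specs=SPECS))"
]
},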
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "8XNdh0Qyqbhl"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 20,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"BATCH_SIZE=32\n",
"train_ds = train_ds.map(\n",
" lambda x: tfio.experimental.serialization.decode_json(x, specs=SPECS)\n",
" )\n",
"\n",
"# Prepare a tuple of (features, label)\n",
"train_ds = train_ds.map(lambda v: (v, v.pop(\"target\")))\n",
"train_ds = train_ds.batch(BATCH_SIZE)\n",
"\n",
"train_ds"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Me0zgeCQIsKH"
},
"source": [
"### Testing dataset"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"id": "2R-I9hUgIcXR"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Connection successful: mongodb://localhost:27017\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 21,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"test_ds = tfio.experimental.mongodb.MongoDBIODataset(\n",
" uri=URI, database=DATABASE, collection=TEST_COLLECTION\n",
" )\n",
"test_ds = test_ds.map(\n",
" lambda x: tfio.experimental.serialization.decode_json(x, specs=SPECS)\n",
" )\n",
"# Prepare a tuple of (features, label)\n",
"test_ds = test_ds.map(lambda v: (v, v.pop(\"target\")))\n",
"test_ds = test_ds.batch(BATCH_SIZE)\n",
"\n",
"test_ds"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7fAC5HDERL4-"
},
"source": [
"### Define the keras preprocessing layers\n",
"\n",
"As per the [structured data tutorial](https://www.tensorflow.org/tutorials/structured_data/preprocessing_layers), it is recommended to use the [Keras Preprocessing Layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing) as they are more intuitive, and can be easily integrated with the models. However, the standard [feature_columns](https://www.tensorflow.org/api_docs/python/tf/feature_column) can also be used.\n",
"\n",
"For a better understanding of the `preprocessing_layers` in classifying structured data, please refer to the [structured data tutorial](https://www.tensorflow.org/tutorials/structured_data/preprocessing_layers)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"id": "CBzR7Li4SaQS"
},
"outputs": [],
"source": [
"def get_normalization_layer(name, dataset):\n",
" # Create a Normalization layer for our feature.\n",
" normalizer = preprocessing.Normalization(axis=None)\n",
"\n",
" # Prepare a Dataset that only yields our feature.\n",
" feature_ds = dataset.map(lambda x, y: x[name])\n",
"\n",
" # Learn the statistics of the data.\n",
" normalizer.adapt(feature_ds)\n",
"\n",
" return normalizer\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "M0X9LEKoUfbU"
},
"outputs": [],
"source": [
"all_inputs = []\n",
"encoded_features = []\n",
"\n",
"for header in numerical_cols:\n",
" numeric_col = tf.keras.Input(shape=(1,), name=header)\n",
" normalization_layer = get_normalization_layer(header, train_ds)\n",
" encoded_numeric_col = normalization_layer(numeric_col)\n",
" all_inputs.append(numeric_col)\n",
" encoded_features.append(encoded_numeric_col)"
]
},
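{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a feel for what an adapted normalization layer does (an optional illustration; the input values below are made up), apply one to a few raw feature values:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The outputs are z-scores based on the statistics learned during adapt().\n",
"photo_amt_normalizer = get_normalization_layer('PhotoAmt', train_ds)\n",
"print(photo_amt_normalizer(tf.constant([1.0, 5.0, 10.0])))"
]
},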
{
"cell_type": "markdown",
"metadata": {
"id": "x84lZJY164RI"
},
"source": [
"## Build, compile and train the model\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "uuHtpAMqLqmv"
},
"outputs": [],
"source": [
"# Set the parameters\n",
"\n",
"OPTIMIZER=\"adam\"\n",
"LOSS=tf.keras.losses.BinaryCrossentropy(from_logits=True)\n",
"METRICS=['accuracy']\n",
"EPOCHS=10\n"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"id": "7lBmxxuj63jZ"
},
"outputs": [],
"source": [
"# Convert the feature columns into a tf.keras layer\n",
"all_features = tf.keras.layers.concatenate(encoded_features)\n",
"\n",
"# design/build the model\n",
"x = tf.keras.layers.Dense(32, activation=\"relu\")(all_features)\n",
"x = tf.keras.layers.Dropout(0.5)(x)\n",
"x = tf.keras.layers.Dense(64, activation=\"relu\")(x)\n",
"x = tf.keras.layers.Dropout(0.5)(x)\n",
"output = tf.keras.layers.Dense(1)(x)\n",
"model = tf.keras.Model(all_inputs, output)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"id": "LTDFVxpSLfXI"
},
"outputs": [],
"source": [
"# compile the model\n",
"model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"id": "SIJMg-saLgeR"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"109/109 [==============================] - 1s 2ms/step - loss: 0.6261 - accuracy: 0.4711\n",
"Epoch 2/10\n",
"109/109 [==============================] - 0s 3ms/step - loss: 0.5939 - accuracy: 0.6967\n",
"Epoch 3/10\n",
"109/109 [==============================] - 0s 3ms/step - loss: 0.5900 - accuracy: 0.6993\n",
"Epoch 4/10\n",
"109/109 [==============================] - 0s 3ms/step - loss: 0.5846 - accuracy: 0.7146\n",
"Epoch 5/10\n",
"109/109 [==============================] - 0s 3ms/step - loss: 0.5824 - accuracy: 0.7178\n",
"Epoch 6/10\n",
"109/109 [==============================] - 0s 2ms/step - loss: 0.5778 - accuracy: 0.7233\n",
"Epoch 7/10\n",
"109/109 [==============================] - 0s 3ms/step - loss: 0.5810 - accuracy: 0.7083\n",
"Epoch 8/10\n",
"109/109 [==============================] - 0s 3ms/step - loss: 0.5791 - accuracy: 0.7149\n",
"Epoch 9/10\n",
"109/109 [==============================] - 0s 3ms/step - loss: 0.5742 - accuracy: 0.7207\n",
"Epoch 10/10\n",
"109/109 [==============================] - 0s 2ms/step - loss: 0.5797 - accuracy: 0.7083\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 27,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"# fit the model\n",
"model.fit(train_ds, epochs=EPOCHS)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XYJW8za2qm4c"
},
"source": [
"## Infer on the test data"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"id": "6hMtIe1X215P"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"109/109 [==============================] - 0s 2ms/step - loss: 0.5696 - accuracy: 0.7383\n",
"test loss, test acc: [0.569588840007782, 0.7383015751838684]\n"
]
}
],
"source": [
"res = model.evaluate(test_ds)\n",
"print(\"test loss, test acc:\", res)"
]
},
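{
"cell_type": "markdown",
"metadata": {},
"source": [
"Beyond the aggregate metrics, you can also score individual records. Below is a minimal sketch (the feature values are hypothetical) that batches a single sample and converts the model's logit into an adoption probability:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A hypothetical pet: 3 photos, an adoption fee of 100.\n",
"sample = {\"PhotoAmt\": 3, \"Fee\": 100}\n",
"input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}\n",
"# The model outputs logits (from_logits=True), so apply a sigmoid to get a probability.\n",
"prob = tf.nn.sigmoid(model.predict(input_dict))\n",
"print(f\"Probability of adoption: {100 * float(prob[0][0]):.1f}%\")"
]
},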
{
"cell_type": "markdown",
"metadata": {
"id": "2SvFjOJcdRyO"
},
"source": [
"Note: Since the goal of this tutorial is to demonstrate Tensorflow-IO's capability to prepare `tf.data.Datasets` from mongodb and train `tf.keras` models directly, improving the accuracy of the models is out of the current scope. However, the user can explore the dataset and play around with the feature columns and model architectures to get a better classification performance."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P8QAS_3k1y3u"
},
"source": [
"## References:\n",
"\n",
"- [MongoDB](https://docs.mongodb.com/guides/)\n",
"\n",
"- [PetFinder Dataset](https://www.kaggle.com/c/petfinder-adoption-prediction)\n",
"\n",
"- [Classify Structured Data using Keras](https://www.tensorflow.org/tutorials/structured_data/preprocessing_layers#create_compile_and_train_the_model)\n"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "mongodb.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}