{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "OoasdhSAp0zJ"
},
"source": [
"##### Copyright 2019 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "cIrwotvGqsYh"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "C81KT2D_j-xR"
},
"source": [
"# Build a linear model with Estimators"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JOccPOFMm5Tc"
},
"source": [
"> Warning: TensorFlow 2.15 included the final release of the `tf-estimator` package. Estimators will not be available in TensorFlow 2.16 or later. See the [migration guide](https://tensorflow.org/guide/migrate/migrating_estimator) for more information about how to migrate away from Estimators."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tUP8LMdYtWPz"
},
"source": [
"## Overview\n",
"\n",
"This end-to-end walkthrough trains a logistic regression model using the `tf.estimator` API. Logistic regression is often used as a baseline for other, more complex algorithms.\n",
"\n",
"Note: A Keras logistic regression example is [available](https://tensorflow.org/guide/migrate/tutorials/keras/regression) and is recommended over this tutorial.\n"
]
},
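{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of what a logistic regression model computes, the sketch below passes a weighted sum of the inputs through the sigmoid function to get a probability between 0 and 1. It uses plain NumPy rather than `tf.estimator`, and the helper names, weights, bias, and inputs are hypothetical values chosen by hand; `tf.estimator.LinearClassifier` learns its weights from data instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    # Map a real-valued logit to a probability in (0, 1).\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"def predict_proba(x, weights, bias):\n",
"    # Logistic regression: probability = sigmoid(x . weights + bias).\n",
"    return sigmoid(np.dot(x, weights) + bias)\n",
"\n",
"# Hypothetical weights, bias, and inputs, for illustration only.\n",
"example_weights = np.array([0.8, -0.5])\n",
"example_bias = -0.1\n",
"example_inputs = np.array([[1.0, 0.0],\n",
"                           [0.0, 1.0]])\n",
"example_probs = predict_proba(example_inputs, example_weights, example_bias)\n",
"print(example_probs)"
]
},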
{
"cell_type": "markdown",
"metadata": {
"id": "vkC_j6VpqrDw"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rutbJGmpqvm3"
},
"outputs": [],
"source": [
"!pip install scikit-learn\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "54mb4J9PqqDh"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import clear_output"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fsjkwfsGOBMT"
},
"source": [
"## Load the Titanic dataset\n",
"\n",
"You will use the Titanic dataset with the (rather morbid) goal of predicting passenger survival, given characteristics such as gender, age, class, etc."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bNiwh-APcRVD"
},
"outputs": [],
"source": [
"import tensorflow.compat.v2.feature_column as fc\n",
"\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DSeMKcx03d5R"
},
"outputs": [],
"source": [
"# Load dataset.\n",
"dftrain = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')\n",
"dfeval = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')\n",
"y_train = dftrain.pop('survived')\n",
"y_eval = dfeval.pop('survived')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jjm4Qj0u7_cp"
},
"source": [
"## Explore the data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UrQzxKKh4d6u"
},
"source": [
"The dataset contains the following features:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rTjugo3n308g"
},
"outputs": [],
"source": [
"dftrain.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y86q1fj44lZs"
},
"outputs": [],
"source": [
"dftrain.describe()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8JSa_duD4tFZ"
},
"source": [
"There are 627 and 264 examples in the training and evaluation sets, respectively."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Fs3Nu5pV4v5J"
},
"outputs": [],
"source": [
"dftrain.shape[0], dfeval.shape[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RxCA4Nr45AfF"
},
"source": [
"The majority of passengers are in their 20s and 30s."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RYeCMm7K40ZN"
},
"outputs": [],
"source": [
"dftrain.age.hist(bins=20)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DItSwJ_B5B0f"
},
"source": [
"There are approximately twice as many male passengers as female passengers aboard."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "b03dVV9q5Dv2"
},
"outputs": [],
"source": [
"dftrain.sex.value_counts().plot(kind='barh')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rK6WQ29q5Jf5"
},
"source": [
"The majority of passengers were in the \"third\" class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dgpJVeCq5Fgd"
},
"outputs": [],
"source": [
"dftrain['class'].value_counts().plot(kind='barh')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FXJhGGL85TLp"
},
"source": [
"Female passengers had a much higher chance of survival than male passengers, which suggests that sex will be a strongly predictive feature for the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lSZYa7c45Ttt"
},
"outputs": [],
"source": [
"pd.concat([dftrain, y_train], axis=1).groupby('sex').survived.mean().plot(kind='barh').set_xlabel('fraction survived')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qCHvgeorEsHa"
},
"source": [
"## Feature Engineering for the Model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Dhcq8Ds4mCtm"
},
"source": [
"> Warning: The `tf.feature_column` module described in this tutorial is not recommended for new code. Keras preprocessing layers cover this functionality; for migration instructions, see the [Migrating feature columns guide](https://www.tensorflow.org/guide/migrate/migrating_feature_columns). The `tf.feature_column` module was designed for use with TF1 Estimators. It does fall under our [compatibility guarantees](https://tensorflow.org/guide/versions), but will receive no fixes other than for security vulnerabilities."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VqDKQLZn8L-B"
},
"source": [
"Estimators use a system called [feature columns](https://www.tensorflow.org/tutorials/structured_data/feature_columns) to describe how the model should interpret each of the raw input features. An Estimator expects a vector of numeric inputs, and *feature columns* describe how the model should convert each feature.\n",
"\n",
"Selecting and crafting the right set of feature columns is key to learning an effective model. A feature column can be either one of the raw inputs in the original features `dict` (a *base feature column*), or any new column created using transformations defined over one or multiple base columns (a *derived feature column*).\n",
"\n",
"The linear estimator uses both numeric and categorical features. Feature columns work with all TensorFlow estimators, and their purpose is to define the features used for modeling. Additionally, they provide some feature engineering capabilities like one-hot encoding, normalization, and bucketization."
]
},
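{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the transformations mentioned above more concrete, the sketch below applies one-hot encoding, normalization, and bucketization to a tiny hypothetical DataFrame using plain pandas and NumPy. It is not the feature column API itself; the sample values and the age boundaries are assumptions chosen for illustration. The bucketization follows the same convention as `tf.feature_column.bucketized_column`, where boundaries `[a, b]` define the buckets `(-inf, a)`, `[a, b)`, and `[b, +inf)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# A tiny hypothetical sample with one categorical and one numeric feature.\n",
"sample = pd.DataFrame({'sex': ['male', 'female', 'male'],\n",
"                       'age': [22.0, 38.0, 26.0]})\n",
"\n",
"# One-hot encoding: one 0/1 column per distinct category (sorted by name).\n",
"one_hot = pd.get_dummies(sample['sex']).astype(int)\n",
"\n",
"# Normalization: rescale to zero mean and unit (sample) standard deviation.\n",
"age_normalized = (sample['age'] - sample['age'].mean()) / sample['age'].std()\n",
"\n",
"# Bucketization: index of the age range each value falls into.\n",
"age_boundaries = [18, 25, 30, 35, 40, 45, 50, 55, 60, 65]\n",
"age_buckets = np.digitize(sample['age'].to_numpy(), age_boundaries)\n",
"\n",
"print(one_hot)\n",
"print(age_normalized.round(3).tolist())\n",
"print(age_buckets.tolist())"
]
},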
{
"cell_type": "markdown",
"metadata": {
"id": "puZFOhTDkblt"
},
"source": [
"### Base Feature Columns"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GpveXYSsADS6"
},
"outputs": [],
"source": [
"CATEGORICAL_COLUMNS = ['sex', 'n_siblings_spouses', 'parch', 'class', 'deck',\n",
"                       'embark_town', 'alone']\n",
"NUMERIC_COLUMNS = ['age', 'fare']\n",
"\n",
"feature_columns = []\n",
"for feature_name in CATEGORICAL_COLUMNS:\n",
"    # Build the vocabulary from the distinct values seen in the training data.\n",
"    vocabulary = dftrain[feature_name].unique()\n",
"    feature_columns.append(tf.feature_column.categorical_column_with_vocabulary_list(feature_name, vocabulary))\n",
"\n",
"for feature_name in NUMERIC_COLUMNS:\n",
"    feature_columns.append(tf.feature_column.numeric_column(feature_name, dtype=tf.float32))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gt8HMtwOh9lJ"
},
"source": [
"The `input_function` specifies how data is converted to a `tf.data.Dataset` that feeds the input pipeline in a streaming fashion. A `tf.data.Dataset` can be built from many kinds of sources, such as a DataFrame, a CSV-formatted file, and more."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qVtrIHFnAe7w"
},
"outputs": [],
"source": [
"def make_input_fn(data_df, label_df, num_epochs=10, shuffle=True, batch_size=32):\n",
"    def input_function():\n",
"        # Create a tf.data.Dataset of (features dict, label) pairs.\n",
"        ds = tf.data.Dataset.from_tensor_slices((dict(data_df), label_df))\n",
"        if shuffle:\n",
"            # Randomize the order of the examples.\n",
"            ds = ds.shuffle(1000)\n",
"        # Split the dataset into batches and repeat it for num_epochs epochs.\n",
"        ds = ds.batch(batch_size).repeat(num_epochs)\n",
"        return ds\n",
"    return input_function\n",
"\n",
"train_input_fn = make_input_fn(dftrain, y_train)\n",
"eval_input_fn = make_input_fn(dfeval, y_eval, num_epochs=1, shuffle=False)"
]
},
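{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how many batches `train_input_fn` produces, the pure-Python sketch below imitates `ds.shuffle(buffer).batch(batch_size).repeat(num_epochs)` on a list of example indices. It does not use `tf.data`, and the helper names are hypothetical. Because batching happens before repeating, each epoch is batched separately and can end with a smaller batch: with 627 examples and a batch size of 32, each epoch has 19 full batches (608 examples) plus one batch of 19, so 20 batches per epoch over 10 epochs gives 200 batches."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"def buffered_shuffle(items, buffer_size, rng):\n",
"    # Rough analogue of Dataset.shuffle: fill a buffer of up to buffer_size\n",
"    # elements and emit a randomly chosen one each time the buffer is full,\n",
"    # then drain whatever remains in random order.\n",
"    buffer = []\n",
"    for item in items:\n",
"        buffer.append(item)\n",
"        if len(buffer) == buffer_size:\n",
"            yield buffer.pop(rng.randrange(len(buffer)))\n",
"    while buffer:\n",
"        yield buffer.pop(rng.randrange(len(buffer)))\n",
"\n",
"def batch_then_repeat(items, batch_size, num_epochs, shuffle_buffer=None, seed=0):\n",
"    # Each epoch is shuffled (if requested) and then split into batches.\n",
"    rng = random.Random(seed)\n",
"    for _ in range(num_epochs):\n",
"        epoch = list(items)\n",
"        if shuffle_buffer:\n",
"            epoch = list(buffered_shuffle(epoch, shuffle_buffer, rng))\n",
"        for start in range(0, len(epoch), batch_size):\n",
"            yield epoch[start:start + batch_size]\n",
"\n",
"# 627 training examples, batch_size=32, num_epochs=10, as in train_input_fn.\n",
"sketch_batches = list(batch_then_repeat(range(627), batch_size=32,\n",
"                                        num_epochs=10, shuffle_buffer=1000))\n",
"print('Total batches:', len(sketch_batches))\n",
"print('Size of the last batch in the first epoch:', len(sketch_batches[19]))"
]
},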
{
"cell_type": "markdown",
"metadata": {
"id": "P7UMVkQnkrgb"
},
"source": [
"You can inspect the dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8ZcG_3KiCb1M"
},
"outputs": [],
"source": [
"ds = make_input_fn(dftrain, y_train, batch_size=10)()\n",
"for feature_batch, label_batch in ds.take(1):\n",
"    print('Some feature keys:', list(feature_batch.keys()))\n",
"    print()\n",
"    print('A batch of class:', feature_batch['class'].numpy())\n",
"    print()\n",
"    print('A batch of labels:', label_batch.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lMNBMyodjlW3"
},
"source": [
"You can also inspect the result of a specific feature column using the `tf.keras.layers.DenseFeatures` layer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IMjlmbPlDmkB"
},
"outputs": [],
"source": [
"age_column = feature_columns[7]\n",
"tf.keras.layers.DenseFeatures([age_column])(feature_batch).numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f4zrAdCIjr3s"
},
"source": [
"`DenseFeatures` only accepts dense tensors, so to inspect a categorical column you first need to transform it into an indicator column:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1VXmXFTSFEvv"
},
"outputs": [],
"source": [
"gender_column = feature_columns[0]\n",
"tf.keras.layers.DenseFeatures([tf.feature_column.indicator_column(gender_column)])(feature_batch).numpy()"
]
},
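{
"cell_type": "markdown",
"metadata": {},
"source": [
"An indicator column represents each categorical value as a one-hot vector whose positions follow the column's vocabulary list, so each row of the previous cell's result has a single 1. The NumPy sketch below reproduces that mapping with a hypothetical `indicator` helper and a hypothetical vocabulary order; it is an illustration, not how `tf.feature_column.indicator_column` is implemented."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def indicator(values, vocabulary):\n",
"    # One-hot encode each value against an ordered vocabulary list.\n",
"    index = {token: i for i, token in enumerate(vocabulary)}\n",
"    encoded = np.zeros((len(values), len(vocabulary)), dtype=np.float32)\n",
"    for row, value in enumerate(values):\n",
"        if value in index:  # In this sketch, unknown values stay all-zero.\n",
"            encoded[row, index[value]] = 1.0\n",
"    return encoded\n",
"\n",
"# Hypothetical vocabulary order; unique() returns values in order of appearance.\n",
"sex_vocabulary = ['male', 'female']\n",
"print(indicator(['female', 'male', 'male'], sex_vocabulary))"
]
},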
{
"cell_type": "markdown",
"metadata": {
"id": "MEp59g5UkHYY"
},
"source": [
"After adding all the base features to the model, let's train the model. Training a model is just a single command using the `tf.estimator` API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aGXjdnqqdgIs"
},
"outputs": [],
"source": [
"linear_est = tf.estimator.LinearClassifier(feature_columns=feature_columns)\n",
"linear_est.train(train_input_fn)\n",
"result = linear_est.evaluate(eval_input_fn)\n",
"\n",
"clear_output()\n",
"print(result)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3tOan4hDsG6d"
},
"source": [
"### Derived Feature Columns"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NOG2FSTHlAMu"
},
"source": [
"You should now see an accuracy of around 75%. Using each base feature column separately may not be enough to explain the data. For example, the correlation between age and the label may be different for different genders. Therefore, if you only learn a single model weight for `gender=\"Male\"` and `gender=\"Female\"`, you won't capture every age-gender combination (e.g. distinguishing between `gender=\"Male\"` AND `age=\"30\"`, and `gender=\"Male\"` AND `age=\"40\"`).\n",
"\n",
"To learn the differences between different feature combinations, you can add *crossed feature columns* to the model (you can also bucketize the age column before crossing it):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AM-RsDzNfGlu"
},
"outputs": [],
"source": [
"age_x_gender = tf.feature_column.crossed_column(['age', 'sex'], hash_bucket_size=100)"
]
},
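{
"cell_type": "markdown",
"metadata": {},
"source": [
"A crossed column gives each combination of values its own weight by hashing the combination into one of `hash_bucket_size` buckets, so unrelated combinations can occasionally share a bucket. The sketch below illustrates the idea with `hashlib.sha256` and bucketizes age first, as suggested above; the boundaries and the `cross_bucket` helper are hypothetical, and the sketch does not reproduce TensorFlow's own hashing or bucket IDs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import hashlib\n",
"import numpy as np\n",
"\n",
"def cross_bucket(values, hash_bucket_size):\n",
"    # Join the values into one key and hash it into hash_bucket_size buckets.\n",
"    # TensorFlow uses its own fingerprint hash, so real bucket IDs differ.\n",
"    key = '_X_'.join(str(v) for v in values)\n",
"    digest = hashlib.sha256(key.encode('utf-8')).hexdigest()\n",
"    return int(digest, 16) % hash_bucket_size\n",
"\n",
"# Hypothetical age boundaries used to bucketize age before crossing.\n",
"age_edges = [18, 25, 30, 35, 40, 45, 50, 55, 60, 65]\n",
"for age, sex in [(30.0, 'male'), (40.0, 'male'), (30.0, 'female')]:\n",
"    age_bucket = int(np.digitize(age, age_edges))\n",
"    print(age, sex, 'age bucket', age_bucket, '-> cross bucket',\n",
"          cross_bucket([age_bucket, sex], hash_bucket_size=100))"
]
},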
{
"cell_type": "markdown",
"metadata": {
"id": "DqDFyPKQmGTN"
},
"source": [
"After adding the combination feature to the model, let's train the model again:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "s8FV9oPQfS-g"
},
"outputs": [],
"source": [
"derived_feature_columns = [age_x_gender]\n",
"linear_est = tf.estimator.LinearClassifier(feature_columns=feature_columns+derived_feature_columns)\n",
"linear_est.train(train_input_fn)\n",
"result = linear_est.evaluate(eval_input_fn)\n",
"\n",
"clear_output()\n",
"print(result)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rwfdZj7ImLwb"
},
"source": [
"It now achieves an accuracy of about 77.6%, slightly better than the model trained on base features alone. You can try using more features and transformations to see if you can do better!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8_eyb9d-ncjH"
},
"source": [
"Now you can use the trained model to make predictions on passengers from the evaluation set. TensorFlow models are optimized to make predictions on a batch, or collection, of examples at once. Earlier, the `eval_input_fn` was defined using the entire evaluation set."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wiScyBcef6Dq"
},
"outputs": [],
"source": [
"pred_dicts = list(linear_est.predict(eval_input_fn))\n",
"probs = pd.Series([pred['probabilities'][1] for pred in pred_dicts])\n",
"\n",
"probs.plot(kind='hist', bins=20, title='predicted probabilities')"
]
},
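{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `accuracy` reported by `evaluate` compares the predicted class with the label. For this binary classifier, the predicted class is 1 when the probability of class 1 is greater than 0.5 (that is, when the logit is positive). The sketch below computes accuracy from probabilities in that way, using a hypothetical `accuracy_at_threshold` helper and made-up probabilities and labels. With this notebook's `probs` and `y_eval`, `accuracy_at_threshold(probs, y_eval)` should come out close to the accuracy printed earlier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def accuracy_at_threshold(probabilities, labels, threshold=0.5):\n",
"    # Predict class 1 when the probability exceeds the threshold, then\n",
"    # return the fraction of predictions that match the labels.\n",
"    predicted = (np.asarray(probabilities) > threshold).astype(int)\n",
"    return float(np.mean(predicted == np.asarray(labels)))\n",
"\n",
"# Hypothetical probabilities and labels, for illustration only.\n",
"example_probabilities = [0.9, 0.2, 0.6, 0.4]\n",
"example_labels = [1, 0, 0, 0]\n",
"print(accuracy_at_threshold(example_probabilities, example_labels))"
]
},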
{
"cell_type": "markdown",
"metadata": {
"id": "UEHRCd4sqrLs"
},
"source": [
"Finally, look at the receiver operating characteristic (ROC) curve of the results, which gives a better idea of the tradeoff between the true positive rate and the false positive rate."
]
},
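{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each point on an ROC curve corresponds to one decision threshold: examples scored at or above the threshold are predicted positive, and the true and false positive rates are computed for that choice. The NumPy sketch below (hypothetical helper names and toy data) sweeps every distinct score as a threshold and integrates the curve with the trapezoidal rule. It is a simplified stand-in for `sklearn.metrics.roc_curve`, which by default also drops points that would not change the shape of the plotted curve."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def roc_points(labels, scores):\n",
"    # Sweep the threshold over every distinct score, from highest to lowest,\n",
"    # and record (false positive rate, true positive rate) at each step.\n",
"    labels = np.asarray(labels)\n",
"    scores = np.asarray(scores, dtype=float)\n",
"    positives = np.sum(labels == 1)\n",
"    negatives = np.sum(labels == 0)\n",
"    fpr_points, tpr_points = [0.0], [0.0]\n",
"    for threshold in np.unique(scores)[::-1]:\n",
"        predicted_positive = scores >= threshold\n",
"        tpr_points.append(np.sum(predicted_positive & (labels == 1)) / positives)\n",
"        fpr_points.append(np.sum(predicted_positive & (labels == 0)) / negatives)\n",
"    return np.array(fpr_points), np.array(tpr_points)\n",
"\n",
"def trapezoid_auc(x, y):\n",
"    # Area under a piecewise-linear curve via the trapezoidal rule.\n",
"    return float(np.sum((x[1:] - x[:-1]) * (y[1:] + y[:-1]) / 2.0))\n",
"\n",
"# Toy labels and scores, for illustration only.\n",
"sketch_fpr, sketch_tpr = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])\n",
"print(sketch_fpr, sketch_tpr, trapezoid_auc(sketch_fpr, sketch_tpr))"
]
},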
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kqEjsezIokIe"
},
"outputs": [],
"source": [
"from sklearn.metrics import roc_curve\n",
"from matplotlib import pyplot as plt\n",
"\n",
"fpr, tpr, _ = roc_curve(y_eval, probs)\n",
"plt.plot(fpr, tpr)\n",
"plt.title('ROC curve')\n",
"plt.xlabel('false positive rate')\n",
"plt.ylabel('true positive rate')\n",
"plt.xlim(0,)\n",
"plt.ylim(0,)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "linear.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}