Description
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Debian 10
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary): Binary
- TensorFlow version: v2.4.0-0-g582c8d236cb 2.4.0
- Python version: 3.7.9
- Bazel version (if compiling from source): n/a
- GCC/Compiler version (if compiling from source): n/a
- CUDA/cuDNN version: n/a
- GPU model and memory: n/a
Describe the current behavior
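Running a simple Estimator training job with MultiWorkerMirroredStrategy fails with TypeError: can't pickle _thread.lock objects. According to the traceback in the logs, the error is raised while tf.estimator.Estimator is being constructed, when RunConfig.replace deep-copies the RunConfig, including the strategy passed as train_distribute.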
Describe the expected behavior
The training should proceed without errors.
Standalone code to reproduce the issue
The example must run in a distributed environment to reproduce the issue: save the script as script.py and run it in three separate terminals, one per task (the chief and two workers):
```shell
TF_CONFIG='{"cluster": {"chief": ["localhost:2222"], "worker": ["localhost:2223", "localhost:2224"]}, "task": {"type": "chief", "index": 0}}' python script.py
TF_CONFIG='{"cluster": {"chief": ["localhost:2222"], "worker": ["localhost:2223", "localhost:2224"]}, "task": {"type": "worker", "index": 0}}' python script.py
TF_CONFIG='{"cluster": {"chief": ["localhost:2222"], "worker": ["localhost:2223", "localhost:2224"]}, "task": {"type": "worker", "index": 1}}' python script.py
```
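The three TF_CONFIG values differ only in their "task" entry. As a convenience, they can be generated rather than typed by hand. The sketch below assumes the same one-chief, two-worker localhost cluster; make_tf_config and all_task_configs are hypothetical helper names, not part of TensorFlow, and the code only builds the JSON strings that each process would read from its TF_CONFIG environment variable.

```python
import json

# Assumed cluster layout, copied from the three commands above:
# one chief and two workers, all on localhost.
CLUSTER = {
    "chief": ["localhost:2222"],
    "worker": ["localhost:2223", "localhost:2224"],
}


def make_tf_config(task_type, task_index, cluster=CLUSTER):
    """Hypothetical helper: return the TF_CONFIG JSON string for one task."""
    if not 0 <= task_index < len(cluster[task_type]):
        raise ValueError(f"no {task_type} task with index {task_index}")
    return json.dumps({"cluster": cluster,
                       "task": {"type": task_type, "index": task_index}})


def all_task_configs(cluster=CLUSTER):
    """Hypothetical helper: yield (task_type, index, TF_CONFIG) for every task."""
    for task_type, addresses in cluster.items():
        for index in range(len(addresses)):
            yield task_type, index, make_tf_config(task_type, index, cluster)


if __name__ == "__main__":
    # Print one launch command per task; run each in its own terminal.
    for task_type, index, tf_config in all_task_configs():
        print(f"TF_CONFIG='{tf_config}' python script.py")
```

With json.dumps' default separators, the printed commands match the three above, in the order chief, worker 0, worker 1.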
```python
import tensorflow as tf
import tensorflow_datasets as tfds

buffer_size = 10000
batch_size = 64
learning_rate = 1e-4


def input_fn(mode, input_context=None):
    tfds.disable_progress_bar()
    datasets, _ = tfds.load(name='mnist', with_info=True, as_supervised=True)
    mnist_dataset = (
        datasets['train']
        if mode == tf.estimator.ModeKeys.TRAIN else datasets['test'])

    def scale(image, label):
        image = tf.cast(image, tf.float32)
        image /= 255
        return image, label

    # Give each input pipeline a disjoint shard of the data.
    if input_context:
        mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines,
                                            input_context.input_pipeline_id)
    return mnist_dataset.map(scale).cache().shuffle(buffer_size).batch(batch_size)


def model_fn(features, labels, mode):
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    logits = model(features, training=False)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {'logits': logits}
        # EstimatorSpec takes `mode` as its first argument and has no `labels` argument.
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    optimizer = tf.compat.v1.train.GradientDescentOptimizer(
        learning_rate=learning_rate)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits)
    # Scale the summed per-example losses by 1 / batch_size.
    loss = tf.reduce_sum(loss) * (1. / batch_size)
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss)

    logging_hook = tf.estimator.LoggingTensorHook({'loss': loss}, every_n_iter=10)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        training_hooks=[logging_hook],
        train_op=optimizer.minimize(
            loss, tf.compat.v1.train.get_or_create_global_step()))


strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)
classifier = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='/tmp/multiworker', config=config)
tf.estimator.train_and_evaluate(
    classifier,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn))
```
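input_fn passes input_context.num_input_pipelines and input_context.input_pipeline_id to Dataset.shard, so each pipeline reads a disjoint subset of MNIST. The pure-Python sketch below shows which elements each shard keeps, assuming three input pipelines (one per task in the cluster above); shard_list is a hypothetical stand-in for tf.data.Dataset.shard, not TensorFlow's implementation.

```python
def shard_list(elements, num_shards, index):
    """Keep the elements at positions i where i % num_shards == index.

    Hypothetical pure-Python stand-in for tf.data.Dataset.shard.
    """
    if not 0 <= index < num_shards:
        raise ValueError("index must satisfy 0 <= index < num_shards")
    return [e for i, e in enumerate(elements) if i % num_shards == index]


# Assumption: three input pipelines, one per task (chief, worker 0, worker 1).
num_input_pipelines = 3
examples = list(range(10))
shards = [shard_list(examples, num_input_pipelines, pipeline_id)
          for pipeline_id in range(num_input_pipelines)]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```

The shards are disjoint and together cover every example. Because input_fn shards before map, cache and shuffle, each pipeline caches and shuffles only its own slice.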
Other info / logs
Full logs:
```
TF_CONFIG='{"cluster": {"chief": ["localhost:2222"], "worker": ["localhost:2223", "localhost:2224"]}, "task": {"type": "worker", "index": 1}}' python script.py
WARNING:tensorflow:From script.py:68: _CollectiveAllReduceStrategyExperimental.__init__ (from tensorflow.python.distribute.collective_all_reduce_strategy) is deprecated and will be removed in a future version.
Instructions for updating:
use distribute.MultiWorkerMirroredStrategy instead
2021-01-20 18:24:44.477611: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2021-01-20 18:24:44.479538: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
2021-01-20 18:24:44.491607: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job chief -> {0 -> localhost:2222}
2021-01-20 18:24:44.491654: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job worker -> {0 -> localhost:2223, 1 -> localhost:2224}
2021-01-20 18:24:44.492211: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:411] Started server with target: grpc://localhost:2224
Traceback (most recent call last):
  File "script.py", line 73, in <module>
    model_fn=model_fn, model_dir='/tmp/multiworker', config=config)
  File "/opt/conda/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 183, in __init__
    config, model_dir)
  File "/opt/conda/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1832, in maybe_overwrite_model_dir_and_session_config
    config = run_config.RunConfig.replace(config, session_config=session_config)
  File "/opt/conda/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/run_config.py", line 923, in replace
    copy.deepcopy(self),
  File "/opt/conda/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/opt/conda/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/opt/conda/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/opt/conda/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/opt/conda/lib/python3.7/copy.py", line 161, in deepcopy
    y = copier(memo)
  File "/opt/conda/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1542, in __deepcopy__
    setattr(result, k, copy.deepcopy(v, memo))
  File "/opt/conda/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/opt/conda/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/opt/conda/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/opt/conda/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/opt/conda/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/opt/conda/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/opt/conda/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/opt/conda/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/opt/conda/lib/python3.7/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/opt/conda/lib/python3.7/copy.py", line 281, in _reconstruct
    state = deepcopy(state, memo)
  File "/opt/conda/lib/python3.7/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/opt/conda/lib/python3.7/copy.py", line 241, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/opt/conda/lib/python3.7/copy.py", line 169, in deepcopy
    rv = reductor(4)
TypeError: can't pickle _thread.lock objects
```
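Reading the traceback from the top: RunConfig.replace calls copy.deepcopy on the config, the __deepcopy__ method in distribute_lib.py (line 1542 here) deep-copies each attribute of the strategy, and the recursion eventually reaches a _thread.lock, which cannot be pickled, so rv = reductor(4) raises. The pure-Python sketch below reproduces that failure mode without TensorFlow. Holder, LockSafeHolder, Config and try_deepcopy are hypothetical stand-ins, not TensorFlow classes, and sharing the lock in __deepcopy__ only illustrates how an object can make itself deep-copyable; it is not a fix or workaround for this issue. Python 3.7 reports "can't pickle _thread.lock objects"; newer Python versions word the message differently.

```python
import copy
import threading

_LOCK_TYPE = type(threading.Lock())  # the _thread.lock type


class Holder:
    """Hypothetical stand-in for an object that owns a lock (not a TF class)."""

    def __init__(self):
        self.value = 42
        self._lock = threading.Lock()


class LockSafeHolder(Holder):
    """Hypothetical variant whose __deepcopy__ shares the lock instead of copying it."""

    def __deepcopy__(self, memo):
        result = self.__class__.__new__(self.__class__)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if isinstance(v, _LOCK_TYPE):
                setattr(result, k, v)  # locks cannot be pickled, so reuse the same one
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result


class Config:
    """Hypothetical stand-in for RunConfig: it simply holds the strategy."""

    def __init__(self, strategy):
        self.train_distribute = strategy


def try_deepcopy(obj):
    """Return None if copy.deepcopy(obj) succeeds, else the TypeError it raised."""
    try:
        copy.deepcopy(obj)
    except TypeError as exc:
        return exc
    return None


print(repr(try_deepcopy(Config(Holder()))))    # a TypeError mentioning _thread.lock
print(try_deepcopy(Config(LockSafeHolder())))  # None
```

Whether a library should share, recreate, or skip such a lock when copying is its own design decision; the sketch only shows why a plain deep copy of an object graph containing a lock fails.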