# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=redefined-outer-name
# pylint: disable=redefined-builtin
# pylint: disable=g-classes-have-attributes
"""Keras backend API."""
import tensorflow.compat.v2 as tf
import collections
import itertools
import json
import os
import sys
import threading
import warnings
import weakref
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager.context import get_config
from tensorflow.python.framework import config
from keras import backend_config
from keras.distribute import distribute_coordinator_utils as dc
from keras.engine import keras_tensor
from keras.utils import control_flow_util
from keras.utils import object_identity
from keras.utils import tf_contextlib
from keras.utils import tf_inspect
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls
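# `all`, `any` and `sum` are shadowed later in this module by backend ops of
# the same name (hence the `redefined-builtin` pylint disable above), so keep
# aliases to the Python builtins for internal use.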
py_all = all
py_sum = sum
py_any = any
# INTERNAL UTILS
# The internal graph maintained by Keras and used by the symbolic Keras APIs
# while executing eagerly (such as the functional API for model-building).
# This is thread-local to allow building separate models in different threads
# concurrently, but comes at the cost of not being able to build one model
# across threads.
_GRAPH = threading.local()
# A graph which is used for constructing functions in eager mode.
_CURRENT_SCRATCH_GRAPH = threading.local()
# This is a thread-local object that will hold the default internal TF session
# used by Keras. It can be set manually via `set_session(sess)`.
_SESSION = threading.local()
# A global dictionary mapping graph objects to an index of counters used
# for various layer/optimizer names in each graph.
# This allows layers to be given unique autogenerated names, in a
# graph-specific way.
PER_GRAPH_OBJECT_NAME_UIDS = weakref.WeakKeyDictionary()
# A global set tracking what object names have been seen so far.
# Optionally used as an avoid-list when generating names.
OBSERVED_NAMES = set()
# _DUMMY_EAGER_GRAPH.key is used as a key in _GRAPH_LEARNING_PHASES.
# We keep a separate reference to it to make sure it does not get removed from
# _GRAPH_LEARNING_PHASES.
# _DummyEagerGraph inherits from threading.local to make its `key` attribute
# thread local. This is needed to make set_learning_phase affect only the
# current thread during eager execution (see b/123096885 for more details).
class _DummyEagerGraph(threading.local):
"""_DummyEagerGraph provides a thread local `key` attribute.
We can't use threading.local directly, i.e. without subclassing, because
gevent monkey patches threading.local and its version does not support
weak references.
"""
class _WeakReferencableClass:
"""This dummy class is needed for two reasons.
- We need something that supports weak references. Basic types like string
and ints don't.
- We need something whose hash and equality are based on object identity
to make sure they are treated as different keys to _GRAPH_LEARNING_PHASES.
An empty Python class satisfies both of these requirements.
"""
pass
def __init__(self):
    # Constructors for classes subclassing threading.local run once per thread
    # that accesses an attribute of the instance. Thus, each thread will get a
    # different key.
super(_DummyEagerGraph, self).__init__()
self.key = _DummyEagerGraph._WeakReferencableClass()
self.learning_phase_is_set = False
_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
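# Illustrative sketch (comment only, not executed here): because
# `_DummyEagerGraph` subclasses threading.local, each thread observes its own
# `key` object, e.g.:
#
#   keys = []
#   t = threading.Thread(target=lambda: keys.append(_DUMMY_EAGER_GRAPH.key))
#   t.start()
#   t.join()
#   assert keys[0] is not _DUMMY_EAGER_GRAPH.key  # distinct per-thread keys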
# This boolean flag can be set to True to leave variable initialization
# up to the user.
# Change its value via `manual_variable_initialization(value)`.
_MANUAL_VAR_INIT = False
# This list holds the available devices.
# It is populated when `_get_available_gpus()` is called for the first time.
# We assume our devices don't change henceforth.
_LOCAL_DEVICES = None
# The functions below are kept accessible from the backend for compatibility.
epsilon = backend_config.epsilon
floatx = backend_config.floatx
image_data_format = backend_config.image_data_format
set_epsilon = backend_config.set_epsilon
set_floatx = backend_config.set_floatx
set_image_data_format = backend_config.set_image_data_format
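# Illustrative values under the default configuration (they differ if
# keras.json or the `set_*` functions above have been used):
#
#   floatx()             # -> 'float32'
#   epsilon()            # -> 1e-07
#   image_data_format()  # -> 'channels_last'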
@keras_export('keras.backend.backend')
@doc_controls.do_not_generate_docs
def backend():
"""Publicly accessible method for determining the current backend.
Only exists for API compatibility with multi-backend Keras.
Returns:
The string "tensorflow".
"""
return 'tensorflow'
@keras_export('keras.backend.cast_to_floatx')
@tf.__internal__.dispatch.add_dispatch_support
@doc_controls.do_not_generate_docs
def cast_to_floatx(x):
"""Cast a Numpy array to the default Keras float type.
Args:
x: Numpy array or TensorFlow tensor.
Returns:
The same array (Numpy array if `x` was a Numpy array, or TensorFlow tensor
if `x` was a tensor), cast to its new type.
Example:
>>> tf.keras.backend.floatx()
'float32'
>>> arr = np.array([1.0, 2.0], dtype='float64')
>>> arr.dtype
dtype('float64')
>>> new_arr = cast_to_floatx(arr)
>>> new_arr
array([1., 2.], dtype=float32)
>>> new_arr.dtype
dtype('float32')
"""
if isinstance(x, (tf.Tensor,
tf.Variable,
tf.SparseTensor)):
return tf.cast(x, dtype=floatx())
return np.asarray(x, dtype=floatx())
@keras_export('keras.backend.get_uid')
def get_uid(prefix=''):
"""Associates a string prefix with an integer counter in a TensorFlow graph.
Args:
prefix: String prefix to index.
Returns:
Unique integer ID.
Example:
>>> get_uid('dense')
1
>>> get_uid('dense')
2
"""
graph = get_graph()
if graph not in PER_GRAPH_OBJECT_NAME_UIDS:
PER_GRAPH_OBJECT_NAME_UIDS[graph] = collections.defaultdict(int)
layer_name_uids = PER_GRAPH_OBJECT_NAME_UIDS[graph]
layer_name_uids[prefix] += 1
return layer_name_uids[prefix]
@keras_export('keras.backend.reset_uids')
def reset_uids():
"""Resets graph identifiers.
"""
PER_GRAPH_OBJECT_NAME_UIDS.clear()
OBSERVED_NAMES.clear()
@keras_export('keras.backend.clear_session')
def clear_session():
"""Resets all state generated by Keras.
Keras manages a global state, which it uses to implement the Functional
model-building API and to uniquify autogenerated layer names.
If you are creating many models in a loop, this global state will consume
an increasing amount of memory over time, and you may want to clear it.
Calling `clear_session()` releases the global state: this helps avoid clutter
from old models and layers, especially when memory is limited.
Example 1: calling `clear_session()` when creating models in a loop
```python
  for _ in range(100):
    # Without `clear_session()`, each iteration of this loop will
    # slightly increase the size of the global state managed by Keras
    model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)])

  for _ in range(100):
    # With `clear_session()` called at the beginning,
    # Keras starts with a blank state at each iteration
    # and memory consumption is constant over time.
    tf.keras.backend.clear_session()
    model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)])