import tensorflow as tf
from tensorflow.contrib.timeseries.python.timeseries import model as ts_model
class _LSTMModel(ts_model.SequentialTimeSeriesModel):
def __init__(self, num_units, num_features, dtype=tf.float32):
"""Initialize/configure the model object.
Note that we do not start graph building here. Rather, this object is a
configurable factory for TensorFlow graphs which are run by an Estimator.
Args:
num_units: The number of units in the model's LSTMCell.
num_features: The dimensionality of the time series (features per
timestep).
dtype: The floating point data type to use.
"""
super(_LSTMModel, self).__init__(
# Pre-register the metrics we'll be outputting (just a mean here).
train_output_names=["mean"],
predict_output_names=["mean"],
num_features=num_features,
dtype=dtype)
self._num_units = num_units
# Filled in by initialize_graph()
self._lstm_cell = None
self._lstm_cell_run = None
self._predict_from_lstm_output = None
def initialize_graph(self, input_statistics):
"""Save templates for components, which can then be used repeatedly.
This method is called every time a new graph is created. It's safe to start
adding ops to the current default graph here, but the graph should be
constructed from scratch.
Args:
input_statistics: A math_utils.InputStatistics object.
"""
super(_LSTMModel, self).initialize_graph(input_statistics=input_statistics)
self._lstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=self._num_units)
# Create templates so we don't have to worry about variable reuse.
self._lstm_cell_run = tf.make_template(
name_="lstm_cell",
func_=self._lstm_cell,
create_scope_now_=True)
# Transforms LSTM output into mean predictions.
self._predict_from_lstm_output = tf.make_template(
name_="predict_from_lstm_output",
func_=lambda inputs: tf.layers.dense(inputs=inputs, units=self.num_features),
create_scope_now_=True)
def get_start_state(self):
"""Return initial state for the time series model."""
return (
# Keeps track of the time associated with this state for error checking.
tf.zeros([], dtype=tf.int64),
# The previous observation or prediction.
tf.zeros([self.num_features], dtype=self.dtype),
# The state of the RNNCell (batch dimension removed since this parent
# class will broadcast).
[tf.squeeze(state_element, axis=0)
for state_element
in self._lstm_cell.zero_state(batch_size=1, dtype=self.dtype)])
def _transform(self, data):
"""Normalize data based on input statistics to encourage stable training."""
mean, variance = self._input_statistics.overall_feature_moments
return (data - mean) / variance
def _de_transform(self, data):
"""Transform data back to the input scale."""
mean, variance = self._input_statistics.overall_feature_moments
return data * variance + mean
def _filtering_step(self, current_times, current_values, state, predictions):
"""Update model state based on observations.
Note that we don't do much here aside from computing a loss. In this case
it's easier to update the RNN state in _prediction_step, since that covers
running the RNN both on observations (from this method) and our own
predictions. This distinction can be important for probabilistic models,
where repeatedly predicting without filtering should lead to low-confidence
predictions.
Args:
current_times: A [batch size] integer Tensor.
current_values: A [batch size, self.num_features] floating point Tensor
with new observations.
state: The model's state tuple.
predictions: The output of the previous `_prediction_step`.
Returns:
A tuple of new state and a predictions dictionary updated to include a
loss (note that we could also return other measures of goodness of fit,
although only "loss" will be optimized).
"""
state_from_time, prediction, lstm_state = state
with tf.control_dependencies(
[tf.assert_equal(current_times, state_from_time)]):
transformed_values = self._transform(current_values)
# Use mean squared error across features for the loss.
predictions["loss"] = tf.reduce_mean(
(prediction - transformed_values) ** 2, axis=-1)
# Keep track of the new observation in model state. It won't be run
# through the LSTM until the next _imputation_step.
new_state_tuple = (current_times, transformed_values, lstm_state)
return (new_state_tuple, predictions)
def _prediction_step(self, current_times, state):
"""Advance the RNN state using a previous observation or prediction."""
_, previous_observation_or_prediction, lstm_state = state
lstm_output, new_lstm_state = self._lstm_cell_run(
inputs=previous_observation_or_prediction, state=lstm_state)
next_prediction = self._predict_from_lstm_output(lstm_output)
new_state_tuple = (current_times, next_prediction, new_lstm_state)
return new_state_tuple, {"mean": self._de_transform(next_prediction)}
def _imputation_step(self, current_times, state):
"""Advance model state across a gap."""
# Does not do anything special if we're jumping across a gap. More advanced
# models, especially probabilistic ones, would want a special case that
# depends on the gap size.
return state
def _exogenous_input_step(
self, current_times, current_exogenous_regressors, state):
"""Update model state based on exogenous regressors."""
raise NotImplementedError(
"Exogenous inputs are not implemented for this example.")