我正在研究Tensorflow,但模型拟合遇到问题

我开始通过跟随导师来通过互联网视频研究tensorflow,但是我遇到了麻烦(我严格按照视频中的代码进行操作,但是我的显示错误)

我的代码是这个。

import tensorflow as tf
from tensorflow.keras import layers


from tensorflow.keras import datasets

(train_x, train_y), (test_x, test_y) = datasets.mnist.load_data()

inputs = layers.Input((28, 28, 1))
net = layers.Conv2D(32, (3, 3), padding='SAME')(inputs)
net = layers.Activation('relu')(net)
net = layers.Conv2D(32, (3, 3), padding='SAME')(net)
net = layers.Activation('relu')(net)
net = layers.MaxPooling2D(pool_size=(2, 2))(net)
net = layers.Dropout(0.25)(net)

net = layers.Conv2D(64, (3, 3), padding='SAME')(net)
net = layers.Activation('relu')(net)
net = layers.Conv2D(64, (3, 3), padding='SAME')(net)
net = layers.Activation('relu')(net)
net = layers.MaxPooling2D(pool_size=(2, 2))(net)
net = layers.Dropout(0.25)(net)

net = layers.Flatten()(net)
net = layers.Dense(512)(net)
net = layers.Activation('relu')(net)
net = layers.Dropout(0.5)(net)
net = layers.Dense(10)(net)  # num_classes
net = layers.Activation('softmax')(net)

model = tf.keras.Model(inputs=inputs, outputs=net, name='Basic_CNN')

model.summary()



loss_fun = tf.keras.losses.sparse_categorical_crossentropy 

metrics = [tf.keras.metrics.Accuracy()]

optm = tf.keras.optimizers.Adam()

model.compile(optimizer=tf.keras.optimizers.Adam(), 
          loss='sparse_categorical_crossentropy', 
          metrics=[tf.keras.metrics.Accuracy()])

train_x.shape, train_y.shape

test_x.shape, test_y.shape

import numpy as np

train_x = train_x[..., tf.newaxis]
test_x = test_x[..., tf.newaxis]

train_x.shape

test_x.shape

np.min(train_x), np.max(train_x)

train_x = train_x / 255.
test_x = test_x / 255.

np.min(train_x), np.max(train_x)

我下面的模型拟合代码是这个。

num_epochs = 10
batch_size = 32

train_y.shape

model.fit(train_x,train_y, 
      batch_size=32, 
      shuffle=True, 
      epochs=num_epochs)

当我运行这段代码时,我得到了这个错误。 大声笑。

Train on 60000 samples
Epoch 1/10

32/60000 [.....................]-预计到达时间:3:50

ValueError                                Traceback (most recent call last)
<ipython-input-4-d49e2292bdcf> in <module>
  7           batch_size=32,
  8           shuffle=True,

----> 9个纪元= num_epochs)

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
726         max_queue_size=max_queue_size,
727         workers=workers,
--> 728         use_multiprocessing=use_multiprocessing)
729 
730   def evaluate(self,

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, **kwargs)
322                 mode=ModeKeys.TRAIN,
323                 training_context=training_context,
--> 324                 total_epochs=epochs)
325             cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)
326 

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py in run_one_epoch(model, iterator, execution_function, dataset_size, batch_size, strategy, steps_per_epoch, num_samples, mode, training_context, total_epochs)
121         step=step, mode=mode, size=current_batch_size) as batch_logs:
122       try:
--> 123         batch_outs = execution_function(iterator)
124       except (StopIteration, errors.OutOfRangeError):
125         # TODO(kaftan): File bug about tf function and errors.OutOfRangeError?

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py in execution_function(input_fn)
 84     # `numpy` translates Tensors to values in Eager mode.
 85     return nest.map_structure(_non_none_constant_value,
---> 86                               distributed_function(input_fn))
 87 
 88   return execution_function

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\eager\def_function.py in __call__(self, *args, **kwds)
455 
456     tracing_count = self._get_tracing_count()
--> 457     result = self._call(*args, **kwds)
458     if tracing_count == self._get_tracing_count():
459       self._call_counter.called_without_tracing()

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\eager\def_function.py in _call(self, *args, **kwds)
501       # This is the first call of __call__, so we have to initialize.
502       initializer_map = object_identity.ObjectIdentityDictionary()
--> 503       self._initialize(args, kwds, add_initializers_to=initializer_map)
504     finally:
505       # At this point we know that the initialization is complete (or less

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\eager\def_function.py in _initialize(self, args, kwds, add_initializers_to)
406     self._concrete_stateful_fn = (
407         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 408             *args, **kwds))
409 
410     def invalid_creator_scope(*unused_args, **unused_kwds):

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\eager\function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
1846     if self.input_signature:
1847       args, kwargs = None, None
-> 1848     graph_function, _, _ = self._maybe_define_function(args, kwargs)
1849     return graph_function
1850 

 ~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\eager\function.py in _maybe_define_function(self, args, kwargs)
2148         graph_function = self._function_cache.primary.get(cache_key, None)
2149         if graph_function is None:
-> 2150           graph_function = self._create_graph_function(args, kwargs)
2151           self._function_cache.primary[cache_key] = graph_function
2152         return graph_function, args, kwargs

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\eager\function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
2039             arg_names=arg_names,
2040             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2041             capture_by_value=self._capture_by_value),
2042         self._function_attributes,
2043         # Tell the ConcreteFunction to clean up its graph once it goes out of

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\framework\func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
913                                           converted_func)
914 
--> 915       func_outputs = python_func(*func_args, **func_kwargs)
916 
917       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\eager\def_function.py in wrapped_fn(*args, **kwds)
356         # __wrapped__ allows AutoGraph to swap in a converted function. We give
357         # the function a weak reference to itself to avoid a reference cycle.
--> 358         return weak_wrapped_fn().__wrapped__(*args, **kwds)
359     weak_wrapped_fn = weakref.ref(wrapped_fn)
360 

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py in distributed_function(input_iterator)
 71     strategy = distribution_strategy_context.get_strategy()
 72     outputs = strategy.experimental_run_v2(
---> 73         per_replica_function, args=(model, x, y, sample_weights))
 74     # Out of PerReplica outputs reduce or pick values to return.
 75     all_outputs = dist_utils.unwrap_output_dict(

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\distribute\distribute_lib.py in experimental_run_v2(self, fn, args, kwargs)
758       fn = autograph.tf_convert(fn, ag_ctx.control_status_ctx(),
759                                 convert_by_default=False)
--> 760       return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
761 
762   def reduce(self, reduce_op, value, axis):

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\distribute\distribute_lib.py in call_for_each_replica(self, fn, args, kwargs)
1785       kwargs = {}
1786     with self._container_strategy().scope():
-> 1787       return self._call_for_each_replica(fn, args, kwargs)
1788 
1789   def _call_for_each_replica(self, fn, args, kwargs):

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\distribute\distribute_lib.py in _call_for_each_replica(self, fn, args, kwargs)
2130         self._container_strategy(),
2131         replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
-> 2132       return fn(*args, **kwargs)
2133 
2134   def _reduce_to(self, reduce_op, value, destinations):

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\autograph\impl\api.py in wrapper(*args, **kwargs)
290   def wrapper(*args, **kwargs):
291     with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
--> 292       return func(*args, **kwargs)
293 
294   if inspect.isfunction(func) or inspect.ismethod(func):

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py in train_on_batch(model, x, y, sample_weight, class_weight, reset_metrics)
262       y,
263       sample_weights=sample_weights,
--> 264       output_loss_metrics=model._output_loss_metrics)
265 
266   if reset_metrics:

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\engine\training_eager.py in train_on_batch(model, inputs, targets, sample_weights, output_loss_metrics)
313     outs = [outs]
314   metrics_results = _eager_metrics_fn(
--> 315       model, outs, targets, sample_weights=sample_weights, masks=masks)
316   total_loss = nest.flatten(total_loss)
317   return {'total_loss': total_loss,

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\engine\training_eager.py in _eager_metrics_fn(model, outputs, targets, sample_weights, masks)
 72         masks=masks,
 73         return_weighted_and_unweighted_metrics=True,
---> 74         skip_target_masks=model._prepare_skip_target_masks())
 75 
 76   # Add metric results from the `add_metric` metrics.

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\engine\training.py in _handle_metrics(self, outputs, targets, skip_target_masks, sample_weights, masks, return_weighted_metrics, return_weighted_and_unweighted_metrics)
2061           metric_results.extend(
2062               self._handle_per_output_metrics(self._per_output_metrics[i],
-> 2063                                               target, output, output_mask))
2064         if return_weighted_and_unweighted_metrics or return_weighted_metrics:
2065           metric_results.extend(

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\engine\training.py in _handle_per_output_metrics(self, metrics_dict, y_true, y_pred, mask, weights)
2012       with K.name_scope(metric_name):
2013         metric_result = training_utils.call_metric_function(
-> 2014             metric_fn, y_true, y_pred, weights=weights, mask=mask)
2015         metric_results.append(metric_result)
2016     return metric_results

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\engine\training_utils.py in call_metric_function(metric_fn, y_true, y_pred, weights, mask)
1065 
1066   if y_pred is not None:
-> 1067     return metric_fn(y_true, y_pred, sample_weight=weights)
1068   # `Mean` metric only takes a single value.
1069   return metric_fn(y_true, sample_weight=weights)

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\metrics.py in __call__(self, *args, **kwargs)
191     from tensorflow.python.keras.distribute import distributed_training_utils  # pylint:disable=g-import-not-at-top
192     return distributed_training_utils.call_replica_local_fn(
--> 193         replica_local_fn, *args, **kwargs)
194 
195   @property

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\distribute\distributed_training_utils.py in call_replica_local_fn(fn, *args, **kwargs)
1133     with strategy.scope():
1134       return strategy.extended.call_for_each_replica(fn, args, kwargs)
-> 1135   return fn(*args, **kwargs)
1136 
1137 

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\metrics.py in replica_local_fn(*args, **kwargs)
174     def replica_local_fn(*args, **kwargs):
175       """Updates the state of the metric in a replica-local context."""
--> 176       update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
177       with ops.control_dependencies([update_op]):
178         result_t = self.result()  # pylint: disable=not-callable

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\utils\metrics_utils.py in decorated(metric_obj, *args, **kwargs)
 73 
 74     with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
---> 75       update_op = update_state_fn(*args, **kwargs)
 76     if update_op is not None:  # update_op will be None in eager execution.
 77       metric_obj.add_update(update_op)

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\metrics.py in update_state(self, y_true, y_pred, sample_weight)
579         y_pred, y_true)
580 
--> 581     matches = self._fn(y_true, y_pred, **self._fn_kwargs)
582     return super(MeanMetricWrapper, self).update_state(
583         matches, sample_weight=sample_weight)

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\keras\metrics.py in accuracy(y_true, y_pred)
2748       metrics_utils.ragged_assert_compatible_and_get_flat_values(
2749           [y_pred, y_true])
-> 2750   y_pred.shape.assert_is_compatible_with(y_true.shape)
2751   if y_true.dtype != y_pred.dtype:
2752     y_pred = math_ops.cast(y_pred, y_true.dtype)

~\Anaconda3\envs\gp\lib\site-packages\tensorflow_core\python\framework\tensor_shape.py in assert_is_compatible_with(self, other)
1113     """
1114     if not self.is_compatible_with(other):
-> 1115       raise ValueError("Shapes %s and %s are incompatible" % (self, other))
1116 
1117   def most_specific_compatible_shape(self, other):

ValueError: Shapes (32, 10) and (32, 1) are incompatible
评论