diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 981d01dd7be..acc9723c183 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -550,8 +550,12 @@ def _GroupByDevices(self, saveables): """ per_device = collections.defaultdict(lambda: []) for saveable in saveables: - canonical_device = set( - pydev.canonical_name(spec.tensor.device) for spec in saveable.specs) + canonical_device = set() + for spec in saveable.specs: + device_name = pydev.canonical_name(spec.tensor.device) + device_spec = pydev.DeviceSpec.from_string(device_name) + device_spec.device_type = "CPU" + canonical_device.add(device_spec.to_string()) if len(canonical_device) != 1: raise ValueError("All tensors of a saveable object must be " "on the same device: %s" % saveable.name)