diff --git a/Examples/LeNet-MNIST/main.swift b/Examples/LeNet-MNIST/main.swift index e5658701466..61a322c2730 100644 --- a/Examples/LeNet-MNIST/main.swift +++ b/Examples/LeNet-MNIST/main.swift @@ -18,6 +18,14 @@ import Datasets let epochCount = 12 let batchSize = 128 +// Until https://github.com/tensorflow/swift-models/issues/588 is fixed, default to the eager-mode +// device on macOS instead of X10. +#if os(macOS) + let device = Device.defaultTFEager +#else + let device = Device.defaultXLA +#endif + let dataset = MNIST(batchSize: batchSize) // The LeNet-5 model, equivalent to `LeNet` in `ImageClassificationModels`. var classifier = Sequential { @@ -30,65 +38,84 @@ var classifier = Sequential { Dense(inputSize: 120, outputSize: 84, activation: relu) Dense(inputSize: 84, outputSize: 10) } +classifier.move(to: device) -let optimizer = SGD(for: classifier, learningRate: 0.1) +var optimizer = SGD(for: classifier, learningRate: 0.1) +optimizer = SGD(copying: optimizer, to: device) print("Beginning training...") struct Statistics { - var correctGuessCount: Int = 0 - var totalGuessCount: Int = 0 - var totalLoss: Float = 0 + var correctGuessCount = Tensor(0, on: Device.default) + var totalGuessCount = Tensor(0, on: Device.default) + var totalLoss = Tensor(0, on: Device.default) var batches: Int = 0 + + var accuracy: Float { + Float(correctGuessCount.scalarized()) / Float(totalGuessCount.scalarized()) * 100 + } + + var averageLoss: Float { + totalLoss.scalarized() / Float(batches) + } + + init(on device: Device = Device.default) { + correctGuessCount = Tensor(0, on: device) + totalGuessCount = Tensor(0, on: device) + totalLoss = Tensor(0, on: device) + } + + mutating func update(logits: Tensor, labels: Tensor, loss: Tensor) { + let correct = logits.argmax(squeezingAxis: 1) .== labels + correctGuessCount += Tensor(correct).sum() + totalGuessCount += Int32(labels.shape[0]) + totalLoss += loss + batches += 1 + } } // The training loop. for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() { - var trainStats = Statistics() - var testStats = Statistics() + var trainStats = Statistics(on: device) + var testStats = Statistics(on: device) Context.local.learningPhase = .training for batch in epochBatches { - let (images, labels) = (batch.data, batch.label) + let (eagerImages, eagerLabels) = (batch.data, batch.label) + let images = Tensor(copying: eagerImages, to: device) + let labels = Tensor(copying: eagerLabels, to: device) // Compute the gradient with respect to the model. let 𝛁model = TensorFlow.gradient(at: classifier) { classifier -> Tensor in let ŷ = classifier(images) - let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== labels - trainStats.correctGuessCount += Int( - Tensor(correctPredictions).sum().scalarized()) - trainStats.totalGuessCount += images.shape[0] let loss = softmaxCrossEntropy(logits: ŷ, labels: labels) - trainStats.totalLoss += loss.scalarized() - trainStats.batches += 1 + trainStats.update(logits: ŷ, labels: labels, loss: loss) return loss } // Update the model's differentiable variables along the gradient vector. optimizer.update(&classifier, along: 𝛁model) + LazyTensorBarrier() } Context.local.learningPhase = .inference for batch in dataset.validation { - let (images, labels) = (batch.data, batch.label) + let (eagerImages, eagerLabels) = (batch.data, batch.label) + let images = Tensor(copying: eagerImages, to: device) + let labels = Tensor(copying: eagerLabels, to: device) // Compute loss on test set let ŷ = classifier(images) - let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== labels - testStats.correctGuessCount += Int(Tensor(correctPredictions).sum().scalarized()) - testStats.totalGuessCount += images.shape[0] let loss = softmaxCrossEntropy(logits: ŷ, labels: labels) - testStats.totalLoss += loss.scalarized() - testStats.batches += 1 + LazyTensorBarrier() + testStats.update(logits: ŷ, labels: labels, loss: loss) } - let trainAccuracy = Float(trainStats.correctGuessCount) / Float(trainStats.totalGuessCount) - let testAccuracy = Float(testStats.correctGuessCount) / Float(testStats.totalGuessCount) print( """ [Epoch \(epoch + 1)] \ - Training Loss: \(trainStats.totalLoss / Float(trainStats.batches)), \ + Training Loss: \(String(format: "%.3f", trainStats.averageLoss)), \ Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \ - (\(trainAccuracy)), \ - Test Loss: \(testStats.totalLoss / Float(testStats.batches)), \ + (\(String(format: "%.1f", trainStats.accuracy))%), \ + Test Loss: \(String(format: "%.3f", testStats.averageLoss)), \ Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \ - (\(testAccuracy)) + (\(String(format: "%.3f", testStats.accuracy))%) """) -} \ No newline at end of file +}