Convert LeNet-MNIST example to X10 #578

Merged: 2 commits, merged on Jun 9, 2020
Examples/LeNet-MNIST/main.swift: 79 changes (53 additions, 26 deletions)
@@ -18,6 +18,14 @@ import Datasets
let epochCount = 12
let batchSize = 128

+// Until https://github.com/tensorflow/swift-models/issues/588 is fixed, default to the eager-mode
+// device on macOS instead of X10.
+#if os(macOS)
+  let device = Device.defaultTFEager
+#else
+  let device = Device.defaultXLA
+#endif
+
let dataset = MNIST(batchSize: batchSize)
// The LeNet-5 model, equivalent to `LeNet` in `ImageClassificationModels`.
var classifier = Sequential {
@@ -30,65 +38,84 @@ var classifier = Sequential {
    Dense<Float>(inputSize: 120, outputSize: 84, activation: relu)
    Dense<Float>(inputSize: 84, outputSize: 10)
}
+classifier.move(to: device)

-let optimizer = SGD(for: classifier, learningRate: 0.1)
+var optimizer = SGD(for: classifier, learningRate: 0.1)
+optimizer = SGD(copying: optimizer, to: device)

print("Beginning training...")

struct Statistics {
-    var correctGuessCount: Int = 0
-    var totalGuessCount: Int = 0
-    var totalLoss: Float = 0
+    var correctGuessCount = Tensor<Int32>(0, on: Device.default)
+    var totalGuessCount = Tensor<Int32>(0, on: Device.default)
+    var totalLoss = Tensor<Float>(0, on: Device.default)
    var batches: Int = 0
+
+    var accuracy: Float {
+        Float(correctGuessCount.scalarized()) / Float(totalGuessCount.scalarized()) * 100
+    }
+
+    var averageLoss: Float {
+        totalLoss.scalarized() / Float(batches)
+    }
+
+    init(on device: Device = Device.default) {
+        correctGuessCount = Tensor<Int32>(0, on: device)
+        totalGuessCount = Tensor<Int32>(0, on: device)
+        totalLoss = Tensor<Float>(0, on: device)
+    }
+
+    mutating func update(logits: Tensor<Float>, labels: Tensor<Int32>, loss: Tensor<Float>) {
+        let correct = logits.argmax(squeezingAxis: 1) .== labels
+        correctGuessCount += Tensor<Int32>(correct).sum()
+        totalGuessCount += Int32(labels.shape[0])
+        totalLoss += loss
+        batches += 1
+    }
}

// The training loop.
for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
-    var trainStats = Statistics()
-    var testStats = Statistics()
+    var trainStats = Statistics(on: device)
+    var testStats = Statistics(on: device)

    Context.local.learningPhase = .training
    for batch in epochBatches {
-        let (images, labels) = (batch.data, batch.label)
+        let (eagerImages, eagerLabels) = (batch.data, batch.label)
+        let images = Tensor(copying: eagerImages, to: device)
+        let labels = Tensor(copying: eagerLabels, to: device)
        // Compute the gradient with respect to the model.
        let 𝛁model = TensorFlow.gradient(at: classifier) { classifier -> Tensor<Float> in
            let ŷ = classifier(images)
-            let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== labels
-            trainStats.correctGuessCount += Int(
-                Tensor<Int32>(correctPredictions).sum().scalarized())
-            trainStats.totalGuessCount += images.shape[0]
            let loss = softmaxCrossEntropy(logits: ŷ, labels: labels)
-            trainStats.totalLoss += loss.scalarized()
-            trainStats.batches += 1
+            trainStats.update(logits: ŷ, labels: labels, loss: loss)
            return loss
        }
        // Update the model's differentiable variables along the gradient vector.
        optimizer.update(&classifier, along: 𝛁model)
+        LazyTensorBarrier()
    }

    Context.local.learningPhase = .inference
    for batch in dataset.validation {
-        let (images, labels) = (batch.data, batch.label)
+        let (eagerImages, eagerLabels) = (batch.data, batch.label)
+        let images = Tensor(copying: eagerImages, to: device)
+        let labels = Tensor(copying: eagerLabels, to: device)
        // Compute loss on test set
        let ŷ = classifier(images)
-        let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== labels
-        testStats.correctGuessCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
-        testStats.totalGuessCount += images.shape[0]
        let loss = softmaxCrossEntropy(logits: ŷ, labels: labels)
-        testStats.totalLoss += loss.scalarized()
-        testStats.batches += 1
+        LazyTensorBarrier()
+        testStats.update(logits: ŷ, labels: labels, loss: loss)
    }

-    let trainAccuracy = Float(trainStats.correctGuessCount) / Float(trainStats.totalGuessCount)
-    let testAccuracy = Float(testStats.correctGuessCount) / Float(testStats.totalGuessCount)
    print(
        """
        [Epoch \(epoch + 1)] \
-        Training Loss: \(trainStats.totalLoss / Float(trainStats.batches)), \
+        Training Loss: \(String(format: "%.3f", trainStats.averageLoss)), \
        Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \
-        (\(trainAccuracy)), \
-        Test Loss: \(testStats.totalLoss / Float(testStats.batches)), \
+        (\(String(format: "%.1f", trainStats.accuracy))%), \
+        Test Loss: \(String(format: "%.3f", testStats.averageLoss)), \
        Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \
-        (\(testAccuracy))
+        (\(String(format: "%.3f", testStats.accuracy))%)
        """)
}
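For reference, the conversion in this diff boils down to four device-placement steps: select an X10 (or eager) `Device`, move the model's parameters to it with `move(to:)`, copy the optimizer state to the same device with `SGD(copying:to:)`, copy each eager-mode batch onto the device with `Tensor(copying:to:)`, and cut the lazy trace with `LazyTensorBarrier()` after each step. The sketch below condenses that pattern using only the APIs that appear in the diff above; the tiny Flatten-plus-Dense model and the single-epoch loop are illustrative stand-ins, not part of this PR.

```swift
import Datasets
import TensorFlow

// Device selection, mirroring the workaround for issue #588 above:
// eager mode on macOS, X10/XLA elsewhere.
#if os(macOS)
  let device = Device.defaultTFEager
#else
  let device = Device.defaultXLA
#endif

// A deliberately tiny classifier; any Layer-conforming model works the same way.
var model = Sequential {
    Flatten<Float>()
    Dense<Float>(inputSize: 784, outputSize: 10)
}
model.move(to: device)  // 1. move the model's parameters onto the device

var optimizer = SGD(for: model, learningRate: 0.1)
optimizer = SGD(copying: optimizer, to: device)  // 2. copy optimizer state to the same device

let dataset = MNIST(batchSize: 128)

Context.local.learningPhase = .training
for epochBatches in dataset.training.prefix(1) {
    for batch in epochBatches {
        // 3. copy each eager-mode batch onto the device before using it
        let images = Tensor(copying: batch.data, to: device)
        let labels = Tensor(copying: batch.label, to: device)

        let 𝛁model = TensorFlow.gradient(at: model) { model -> Tensor<Float> in
            softmaxCrossEntropy(logits: model(images), labels: labels)
        }
        optimizer.update(&model, along: 𝛁model)

        // 4. cut the lazy trace so XLA compiles and executes the pending work
        LazyTensorBarrier()
    }
}
```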