-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.cpp
71 lines (63 loc) · 2.4 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
//
// main.cpp
// Pytorch
//
// Created by 潘洪岩 on 2019/9/3.
// Copyright © 2019 潘洪岩. All rights reserved.
//
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <vector>
#include <torch/torch.h>
#include <functional>
#include "data/data.h"
#include "embedding/embedding.h"
// Path to the training corpus. NOTE(review): hard-coded absolute path to the
// author's machine — consider taking this from argv or an env var.
std::string TRAINFILE("/Users/panhongyan/bert/sample_text.txt");
template <typename DataLoader>
/// Runs one training epoch over `loader`.
///
/// @param emb        model being trained (gradients accumulate on its params)
/// @param options    dataset/run options (device, batch size, log interval)
/// @param loader     torch data loader yielding {data, target} batches
/// @param optimizer  optimizer holding `emb`'s parameters
/// @param epoch      epoch index, used only for log output
/// @param data_size  total number of training samples, used for progress output
void train(WordEmbedding& emb,torch::data::datasets::Options& options,DataLoader& loader,torch::optim::Optimizer& optimizer,size_t epoch,size_t data_size)
{
    size_t index = 0;       // number of batches processed so far
    emb.train();            // enable training mode (dropout etc., if any)
    float running_loss = 0; // sum of per-batch mean NLL losses this epoch
    for (auto& batch : loader) {
        auto data = batch.data.to(options.device);
        auto targets = batch.target.to(options.device).view({-1});
        auto output = emb.forward(data);
        auto loss = torch::nll_loss(output, targets);
        // Catch a diverged run early rather than training on NaNs.
        assert(!std::isnan(loss.template item<float>()));
        emb.zero_grad();    // clear gradients accumulated by the previous step
        loss.backward();
        optimizer.step();
        running_loss += loss.template item<float>();
        ++index;            // `index` now equals the count of finished batches
        if ((index - 1) % options.log_interval == 0) {
            // Samples seen so far: `index` full batches, capped at data_size.
            // (Was `(index + 1) * batch_size` after a post-increment, which
            // over-reported progress by one batch.)
            auto seen = std::min(data_size, index * options.train_batch_size);
            // Average loss per batch (the summands are already batch means;
            // dividing by a sample count mixed units and shrank toward 0).
            std::cout << "Train Epoch: " << epoch << " " << seen << "/" << data_size
                      << "\tLoss: " << running_loss / index << std::endl;
        }
    }
}
/// Evaluation stub: resolves `word` to its vocabulary id.
/// TODO(review): nothing is done with `wemb` or `id` yet — presumably this
/// should look up / print the embedding for `id`; confirm intended behavior.
void test(WordEmbedding& wemb,Vocab& vocab,std::string& word)
{
    long id = vocab.get_id(word);
    (void)id;   // silence unused-variable warning until the stub is completed
    (void)wemb; // parameter kept for the planned embedding lookup
}
int main() {
torch::manual_seed(1);
torch::data::datasets::Options opts;
auto corpus= torch::data::datasets::loadTrain(TRAINFILE);
torch::data::datasets::EmbeddingTextData text(corpus.data,corpus.vocab,opts);
auto train_text=text.map(torch::data::transforms::Stack<>());
auto train_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(train_text), torch::data::DataLoaderOptions().batch_size(opts.train_batch_size).workers(2));
auto wemb =std::make_shared<WordEmbedding>(corpus.vocab.size(),opts.dim);
wemb->to(opts.device);
torch::optim::Adam adam(wemb->parameters(),torch::optim::AdamOptions(0.01));
size_t data_size=corpus.vocab.tsize();
for (size_t epoch=0;epoch<opts.epoch;epoch++) {
train(*wemb, opts, *train_loader, adam, epoch, data_size);
std::cout << std::endl;
}
torch::save(wemb, "/Users/panhongyan/cbow/cke.pt");
return 0;
}