cifar.c 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. #include "darknet.h"
  2. void train_cifar(char *cfgfile, char *weightfile)
  3. {
  4. srand(time(0));
  5. float avg_loss = -1;
  6. char *base = basecfg(cfgfile);
  7. printf("%s\n", base);
  8. network *net = load_network(cfgfile, weightfile, 0);
  9. printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
  10. char *backup_directory = "/home/pjreddie/backup/";
  11. int classes = 10;
  12. int N = 50000;
  13. char **labels = get_labels("data/cifar/labels.txt");
  14. int epoch = (*net->seen)/N;
  15. data train = load_all_cifar10();
  16. while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
  17. clock_t time=clock();
  18. float loss = train_network_sgd(net, train, 1);
  19. if(avg_loss == -1) avg_loss = loss;
  20. avg_loss = avg_loss*.95 + loss*.05;
  21. printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net->seen);
  22. if(*net->seen/N > epoch){
  23. epoch = *net->seen/N;
  24. char buff[256];
  25. sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
  26. save_weights(net, buff);
  27. }
  28. if(get_current_batch(net)%100 == 0){
  29. char buff[256];
  30. sprintf(buff, "%s/%s.backup",backup_directory,base);
  31. save_weights(net, buff);
  32. }
  33. }
  34. char buff[256];
  35. sprintf(buff, "%s/%s.weights", backup_directory, base);
  36. save_weights(net, buff);
  37. free_network(net);
  38. free_ptrs((void**)labels, classes);
  39. free(base);
  40. free_data(train);
  41. }
  42. void train_cifar_distill(char *cfgfile, char *weightfile)
  43. {
  44. srand(time(0));
  45. float avg_loss = -1;
  46. char *base = basecfg(cfgfile);
  47. printf("%s\n", base);
  48. network *net = load_network(cfgfile, weightfile, 0);
  49. printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
  50. char *backup_directory = "/home/pjreddie/backup/";
  51. int classes = 10;
  52. int N = 50000;
  53. char **labels = get_labels("data/cifar/labels.txt");
  54. int epoch = (*net->seen)/N;
  55. data train = load_all_cifar10();
  56. matrix soft = csv_to_matrix("results/ensemble.csv");
  57. float weight = .9;
  58. scale_matrix(soft, weight);
  59. scale_matrix(train.y, 1. - weight);
  60. matrix_add_matrix(soft, train.y);
  61. while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
  62. clock_t time=clock();
  63. float loss = train_network_sgd(net, train, 1);
  64. if(avg_loss == -1) avg_loss = loss;
  65. avg_loss = avg_loss*.95 + loss*.05;
  66. printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net->seen);
  67. if(*net->seen/N > epoch){
  68. epoch = *net->seen/N;
  69. char buff[256];
  70. sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
  71. save_weights(net, buff);
  72. }
  73. if(get_current_batch(net)%100 == 0){
  74. char buff[256];
  75. sprintf(buff, "%s/%s.backup",backup_directory,base);
  76. save_weights(net, buff);
  77. }
  78. }
  79. char buff[256];
  80. sprintf(buff, "%s/%s.weights", backup_directory, base);
  81. save_weights(net, buff);
  82. free_network(net);
  83. free_ptrs((void**)labels, classes);
  84. free(base);
  85. free_data(train);
  86. }
  87. void test_cifar_multi(char *filename, char *weightfile)
  88. {
  89. network *net = load_network(filename, weightfile, 0);
  90. set_batch_network(net, 1);
  91. srand(time(0));
  92. float avg_acc = 0;
  93. data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
  94. int i;
  95. for(i = 0; i < test.X.rows; ++i){
  96. image im = float_to_image(32, 32, 3, test.X.vals[i]);
  97. float pred[10] = {0};
  98. float *p = network_predict(net, im.data);
  99. axpy_cpu(10, 1, p, 1, pred, 1);
  100. flip_image(im);
  101. p = network_predict(net, im.data);
  102. axpy_cpu(10, 1, p, 1, pred, 1);
  103. int index = max_index(pred, 10);
  104. int class = max_index(test.y.vals[i], 10);
  105. if(index == class) avg_acc += 1;
  106. free_image(im);
  107. printf("%4d: %.2f%%\n", i, 100.*avg_acc/(i+1));
  108. }
  109. }
  110. void test_cifar(char *filename, char *weightfile)
  111. {
  112. network *net = load_network(filename, weightfile, 0);
  113. srand(time(0));
  114. clock_t time;
  115. float avg_acc = 0;
  116. float avg_top5 = 0;
  117. data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
  118. time=clock();
  119. float *acc = network_accuracies(net, test, 2);
  120. avg_acc += acc[0];
  121. avg_top5 += acc[1];
  122. printf("top1: %f, %lf seconds, %d images\n", avg_acc, sec(clock()-time), test.X.rows);
  123. free_data(test);
  124. }
  125. void extract_cifar()
  126. {
  127. char *labels[] = {"airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"};
  128. int i;
  129. data train = load_all_cifar10();
  130. data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
  131. for(i = 0; i < train.X.rows; ++i){
  132. image im = float_to_image(32, 32, 3, train.X.vals[i]);
  133. int class = max_index(train.y.vals[i], 10);
  134. char buff[256];
  135. sprintf(buff, "data/cifar/train/%d_%s",i,labels[class]);
  136. save_image_options(im, buff, PNG, 0);
  137. }
  138. for(i = 0; i < test.X.rows; ++i){
  139. image im = float_to_image(32, 32, 3, test.X.vals[i]);
  140. int class = max_index(test.y.vals[i], 10);
  141. char buff[256];
  142. sprintf(buff, "data/cifar/test/%d_%s",i,labels[class]);
  143. save_image_options(im, buff, PNG, 0);
  144. }
  145. }
  146. void test_cifar_csv(char *filename, char *weightfile)
  147. {
  148. network *net = load_network(filename, weightfile, 0);
  149. srand(time(0));
  150. data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
  151. matrix pred = network_predict_data(net, test);
  152. int i;
  153. for(i = 0; i < test.X.rows; ++i){
  154. image im = float_to_image(32, 32, 3, test.X.vals[i]);
  155. flip_image(im);
  156. }
  157. matrix pred2 = network_predict_data(net, test);
  158. scale_matrix(pred, .5);
  159. scale_matrix(pred2, .5);
  160. matrix_add_matrix(pred2, pred);
  161. matrix_to_csv(pred);
  162. fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1));
  163. free_data(test);
  164. }
  165. void test_cifar_csvtrain(char *cfg, char *weights)
  166. {
  167. network *net = load_network(cfg, weights, 0);
  168. srand(time(0));
  169. data test = load_all_cifar10();
  170. matrix pred = network_predict_data(net, test);
  171. int i;
  172. for(i = 0; i < test.X.rows; ++i){
  173. image im = float_to_image(32, 32, 3, test.X.vals[i]);
  174. flip_image(im);
  175. }
  176. matrix pred2 = network_predict_data(net, test);
  177. scale_matrix(pred, .5);
  178. scale_matrix(pred2, .5);
  179. matrix_add_matrix(pred2, pred);
  180. matrix_to_csv(pred);
  181. fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1));
  182. free_data(test);
  183. }
  184. void eval_cifar_csv()
  185. {
  186. data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
  187. matrix pred = csv_to_matrix("results/combined.csv");
  188. fprintf(stderr, "%d %d\n", pred.rows, pred.cols);
  189. fprintf(stderr, "Accuracy: %f\n", matrix_topk_accuracy(test.y, pred, 1));
  190. free_data(test);
  191. free_matrix(pred);
  192. }
  /*
   * Command-line entry for the cifar subcommands:
   *   argv[2] = command, argv[3] = cfg, argv[4] = weights (optional).
   * "extract" and "eval" ignore cfg/weights but still require argc >= 4.
   * Unknown commands are reported on stderr.
   */
  void run_cifar(int argc, char **argv)
  {
      if(argc < 4){
          fprintf(stderr, "usage: %s %s [train/distill/test/multi/csv/csvtrain/extract/eval] [cfg] [weights (optional)]\n", argv[0], argv[1]);
          return;
      }
      char *cfg = argv[3];
      char *weights = (argc > 4) ? argv[4] : 0;
      if(0==strcmp(argv[2], "train")) train_cifar(cfg, weights);
      else if(0==strcmp(argv[2], "extract")) extract_cifar();
      else if(0==strcmp(argv[2], "distill")) train_cifar_distill(cfg, weights);
      else if(0==strcmp(argv[2], "test")) test_cifar(cfg, weights);
      else if(0==strcmp(argv[2], "multi")) test_cifar_multi(cfg, weights);
      else if(0==strcmp(argv[2], "csv")) test_cifar_csv(cfg, weights);
      else if(0==strcmp(argv[2], "csvtrain")) test_cifar_csvtrain(cfg, weights);
      else if(0==strcmp(argv[2], "eval")) eval_cifar_csv();
      else fprintf(stderr, "Unknown cifar command: %s\n", argv[2]);
  }