/* dice.c (3.5 KB) */
  1. #include "darknet.h"
  /* Label strings for the six dice classes; test_dice indexes this table by
   * predicted class index. */
  char *dice_labels[] = {
      "face1",
      "face2",
      "face3",
      "face4",
      "face5",
      "face6"
  };
  3. void train_dice(char *cfgfile, char *weightfile)
  4. {
  5. srand(time(0));
  6. float avg_loss = -1;
  7. char *base = basecfg(cfgfile);
  8. char *backup_directory = "/home/pjreddie/backup/";
  9. printf("%s\n", base);
  10. network net = parse_network_cfg(cfgfile);
  11. if(weightfile){
  12. load_weights(&net, weightfile);
  13. }
  14. printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
  15. int imgs = 1024;
  16. int i = *net.seen/imgs;
  17. char **labels = dice_labels;
  18. list *plist = get_paths("data/dice/dice.train.list");
  19. char **paths = (char **)list_to_array(plist);
  20. printf("%d\n", plist->size);
  21. clock_t time;
  22. while(1){
  23. ++i;
  24. time=clock();
  25. data train = load_data_old(paths, imgs, plist->size, labels, 6, net.w, net.h);
  26. printf("Loaded: %lf seconds\n", sec(clock()-time));
  27. time=clock();
  28. float loss = train_network(net, train);
  29. if(avg_loss == -1) avg_loss = loss;
  30. avg_loss = avg_loss*.9 + loss*.1;
  31. printf("%d: %f, %f avg, %lf seconds, %ld images\n", i, loss, avg_loss, sec(clock()-time), *net.seen);
  32. free_data(train);
  33. if((i % 100) == 0) net.learning_rate *= .1;
  34. if(i%100==0){
  35. char buff[256];
  36. sprintf(buff, "%s/%s_%d.weights",backup_directory,base, i);
  37. save_weights(net, buff);
  38. }
  39. }
  40. }
  41. void validate_dice(char *filename, char *weightfile)
  42. {
  43. network net = parse_network_cfg(filename);
  44. if(weightfile){
  45. load_weights(&net, weightfile);
  46. }
  47. srand(time(0));
  48. char **labels = dice_labels;
  49. list *plist = get_paths("data/dice/dice.val.list");
  50. char **paths = (char **)list_to_array(plist);
  51. int m = plist->size;
  52. free_list(plist);
  53. data val = load_data_old(paths, m, 0, labels, 6, net.w, net.h);
  54. float *acc = network_accuracies(net, val, 2);
  55. printf("Validation Accuracy: %f, %d images\n", acc[0], m);
  56. free_data(val);
  57. }
  58. void test_dice(char *cfgfile, char *weightfile, char *filename)
  59. {
  60. network net = parse_network_cfg(cfgfile);
  61. if(weightfile){
  62. load_weights(&net, weightfile);
  63. }
  64. set_batch_network(&net, 1);
  65. srand(2222222);
  66. int i = 0;
  67. char **names = dice_labels;
  68. char buff[256];
  69. char *input = buff;
  70. int indexes[6];
  71. while(1){
  72. if(filename){
  73. strncpy(input, filename, 256);
  74. }else{
  75. printf("Enter Image Path: ");
  76. fflush(stdout);
  77. input = fgets(input, 256, stdin);
  78. if(!input) return;
  79. strtok(input, "\n");
  80. }
  81. image im = load_image_color(input, net.w, net.h);
  82. float *X = im.data;
  83. float *predictions = network_predict(net, X);
  84. top_predictions(net, 6, indexes);
  85. for(i = 0; i < 6; ++i){
  86. int index = indexes[i];
  87. printf("%s: %f\n", names[index], predictions[index]);
  88. }
  89. free_image(im);
  90. if (filename) break;
  91. }
  92. }
  /*
   * Command-line entry point for the dice subcommands.
   *
   * Expected arguments: argv[2] is one of "train", "test" or "valid",
   * argv[3] is the cfg path, argv[4] is an optional weights path and
   * argv[5] is an optional image path (passed to test_dice only).
   *
   * Prints a usage line to stderr when fewer than four arguments are given,
   * and a diagnostic when argv[2] is not a recognised subcommand.
   */
  void run_dice(int argc, char **argv)
  {
      if (argc < 4) {
          fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
          return;
      }

      char *command = argv[2];
      char *cfg = argv[3];
      char *weights = (argc > 4) ? argv[4] : NULL;
      char *filename = (argc > 5) ? argv[5] : NULL;

      if (strcmp(command, "test") == 0) {
          test_dice(cfg, weights, filename);
      } else if (strcmp(command, "train") == 0) {
          train_dice(cfg, weights);
      } else if (strcmp(command, "valid") == 0) {
          validate_dice(cfg, weights);
      } else {
          /* Previously an unknown subcommand was silently ignored. */
          fprintf(stderr, "Unknown dice command: %s\n", command);
      }
  }