classifier.c 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098
  1. #include "darknet.h"
  2. #include <sys/time.h>
  3. #include <assert.h>
  /*
   * Split "name value" labels in place and collect the numeric parts.
   *
   * For each of the n entries in labels, the first space is overwritten with
   * a NUL terminator (so the entry keeps only the text before the space) and
   * the text after it is parsed with atof().  An entry that contains no space
   * is left untouched and its value stays 0.
   *
   * Returns a calloc'd array of n floats that the caller must free, or NULL
   * if the allocation fails.
   */
  float *get_regression_values(char **labels, int n)
  {
      int i;
      float *values = calloc(n, sizeof(float));
      if(!values) return NULL;
      for(i = 0; i < n; ++i){
          char *sep = strchr(labels[i], ' ');
          if(!sep) continue;   /* no value part: keep the zero from calloc */
          *sep = 0;
          values[i] = atof(sep + 1);
      }
      return values;
  }
  15. void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
  16. {
  17. int i;
  18. float avg_loss = -1;
  19. char *base = basecfg(cfgfile);
  20. printf("%s\n", base);
  21. printf("%d\n", ngpus);
  22. network **nets = calloc(ngpus, sizeof(network*));
  23. srand(time(0));
  24. int seed = rand();
  25. for(i = 0; i < ngpus; ++i){
  26. srand(seed);
  27. #ifdef GPU
  28. cuda_set_device(gpus[i]);
  29. #endif
  30. nets[i] = load_network(cfgfile, weightfile, clear);
  31. nets[i]->learning_rate *= ngpus;
  32. }
  33. srand(time(0));
  34. network *net = nets[0];
  35. int imgs = net->batch * net->subdivisions * ngpus;
  36. printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
  37. list *options = read_data_cfg(datacfg);
  38. char *backup_directory = option_find_str(options, "backup", "/backup/");
  39. int tag = option_find_int_quiet(options, "tag", 0);
  40. char *label_list = option_find_str(options, "labels", "data/labels.list");
  41. char *train_list = option_find_str(options, "train", "data/train.list");
  42. char *tree = option_find_str(options, "tree", 0);
  43. if (tree) net->hierarchy = read_tree(tree);
  44. int classes = option_find_int(options, "classes", 2);
  45. char **labels = 0;
  46. if(!tag){
  47. labels = get_labels(label_list);
  48. }
  49. list *plist = get_paths(train_list);
  50. char **paths = (char **)list_to_array(plist);
  51. printf("%d\n", plist->size);
  52. int N = plist->size;
  53. double time;
  54. load_args args = {0};
  55. args.w = net->w;
  56. args.h = net->h;
  57. args.threads = 32;
  58. args.hierarchy = net->hierarchy;
  59. args.min = net->min_ratio*net->w;
  60. args.max = net->max_ratio*net->w;
  61. printf("%d %d\n", args.min, args.max);
  62. args.angle = net->angle;
  63. args.aspect = net->aspect;
  64. args.exposure = net->exposure;
  65. args.saturation = net->saturation;
  66. args.hue = net->hue;
  67. args.size = net->w;
  68. args.paths = paths;
  69. args.classes = classes;
  70. args.n = imgs;
  71. args.m = N;
  72. args.labels = labels;
  73. if (tag){
  74. args.type = TAG_DATA;
  75. } else {
  76. args.type = CLASSIFICATION_DATA;
  77. }
  78. data train;
  79. data buffer;
  80. pthread_t load_thread;
  81. args.d = &buffer;
  82. load_thread = load_data(args);
  83. int count = 0;
  84. int epoch = (*net->seen)/N;
  85. while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
  86. if(net->random && count++%40 == 0){
  87. printf("Resizing\n");
  88. int dim = (rand() % 11 + 4) * 32;
  89. //if (get_current_batch(net)+200 > net->max_batches) dim = 608;
  90. //int dim = (rand() % 4 + 16) * 32;
  91. printf("%d\n", dim);
  92. args.w = dim;
  93. args.h = dim;
  94. args.size = dim;
  95. args.min = net->min_ratio*dim;
  96. args.max = net->max_ratio*dim;
  97. printf("%d %d\n", args.min, args.max);
  98. pthread_join(load_thread, 0);
  99. train = buffer;
  100. free_data(train);
  101. load_thread = load_data(args);
  102. for(i = 0; i < ngpus; ++i){
  103. resize_network(nets[i], dim, dim);
  104. }
  105. net = nets[0];
  106. }
  107. time = what_time_is_it_now();
  108. pthread_join(load_thread, 0);
  109. train = buffer;
  110. load_thread = load_data(args);
  111. printf("Loaded: %lf seconds\n", what_time_is_it_now()-time);
  112. time = what_time_is_it_now();
  113. float loss = 0;
  114. #ifdef GPU
  115. if(ngpus == 1){
  116. loss = train_network(net, train);
  117. } else {
  118. loss = train_networks(nets, ngpus, train, 4);
  119. }
  120. #else
  121. loss = train_network(net, train);
  122. #endif
  123. if(avg_loss == -1) avg_loss = loss;
  124. avg_loss = avg_loss*.9 + loss*.1;
  125. printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, *net->seen);
  126. free_data(train);
  127. if(*net->seen/N > epoch){
  128. epoch = *net->seen/N;
  129. char buff[256];
  130. sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
  131. save_weights(net, buff);
  132. }
  133. if(get_current_batch(net)%1000 == 0){
  134. char buff[256];
  135. sprintf(buff, "%s/%s.backup",backup_directory,base);
  136. save_weights(net, buff);
  137. }
  138. }
  139. char buff[256];
  140. sprintf(buff, "%s/%s.weights", backup_directory, base);
  141. save_weights(net, buff);
  142. pthread_join(load_thread, 0);
  143. free_network(net);
  144. if(labels) free_ptrs((void**)labels, classes);
  145. free_ptrs((void**)paths, plist->size);
  146. free_list(plist);
  147. free(base);
  148. }
  149. void validate_classifier_crop(char *datacfg, char *filename, char *weightfile)
  150. {
  151. int i = 0;
  152. network *net = load_network(filename, weightfile, 0);
  153. srand(time(0));
  154. list *options = read_data_cfg(datacfg);
  155. char *label_list = option_find_str(options, "labels", "data/labels.list");
  156. char *valid_list = option_find_str(options, "valid", "data/train.list");
  157. int classes = option_find_int(options, "classes", 2);
  158. int topk = option_find_int(options, "top", 1);
  159. char **labels = get_labels(label_list);
  160. list *plist = get_paths(valid_list);
  161. char **paths = (char **)list_to_array(plist);
  162. int m = plist->size;
  163. free_list(plist);
  164. clock_t time;
  165. float avg_acc = 0;
  166. float avg_topk = 0;
  167. int splits = m/1000;
  168. int num = (i+1)*m/splits - i*m/splits;
  169. data val, buffer;
  170. load_args args = {0};
  171. args.w = net->w;
  172. args.h = net->h;
  173. args.paths = paths;
  174. args.classes = classes;
  175. args.n = num;
  176. args.m = 0;
  177. args.labels = labels;
  178. args.d = &buffer;
  179. args.type = OLD_CLASSIFICATION_DATA;
  180. pthread_t load_thread = load_data_in_thread(args);
  181. for(i = 1; i <= splits; ++i){
  182. time=clock();
  183. pthread_join(load_thread, 0);
  184. val = buffer;
  185. num = (i+1)*m/splits - i*m/splits;
  186. char **part = paths+(i*m/splits);
  187. if(i != splits){
  188. args.paths = part;
  189. load_thread = load_data_in_thread(args);
  190. }
  191. printf("Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time));
  192. time=clock();
  193. float *acc = network_accuracies(net, val, topk);
  194. avg_acc += acc[0];
  195. avg_topk += acc[1];
  196. printf("%d: top 1: %f, top %d: %f, %lf seconds, %d images\n", i, avg_acc/i, topk, avg_topk/i, sec(clock()-time), val.X.rows);
  197. free_data(val);
  198. }
  199. }
  200. void validate_classifier_10(char *datacfg, char *filename, char *weightfile)
  201. {
  202. int i, j;
  203. network *net = load_network(filename, weightfile, 0);
  204. set_batch_network(net, 1);
  205. srand(time(0));
  206. list *options = read_data_cfg(datacfg);
  207. char *label_list = option_find_str(options, "labels", "data/labels.list");
  208. char *valid_list = option_find_str(options, "valid", "data/train.list");
  209. int classes = option_find_int(options, "classes", 2);
  210. int topk = option_find_int(options, "top", 1);
  211. char **labels = get_labels(label_list);
  212. list *plist = get_paths(valid_list);
  213. char **paths = (char **)list_to_array(plist);
  214. int m = plist->size;
  215. free_list(plist);
  216. float avg_acc = 0;
  217. float avg_topk = 0;
  218. int *indexes = calloc(topk, sizeof(int));
  219. for(i = 0; i < m; ++i){
  220. int class = -1;
  221. char *path = paths[i];
  222. for(j = 0; j < classes; ++j){
  223. if(strstr(path, labels[j])){
  224. class = j;
  225. break;
  226. }
  227. }
  228. int w = net->w;
  229. int h = net->h;
  230. int shift = 32;
  231. image im = load_image_color(paths[i], w+shift, h+shift);
  232. image images[10];
  233. images[0] = crop_image(im, -shift, -shift, w, h);
  234. images[1] = crop_image(im, shift, -shift, w, h);
  235. images[2] = crop_image(im, 0, 0, w, h);
  236. images[3] = crop_image(im, -shift, shift, w, h);
  237. images[4] = crop_image(im, shift, shift, w, h);
  238. flip_image(im);
  239. images[5] = crop_image(im, -shift, -shift, w, h);
  240. images[6] = crop_image(im, shift, -shift, w, h);
  241. images[7] = crop_image(im, 0, 0, w, h);
  242. images[8] = crop_image(im, -shift, shift, w, h);
  243. images[9] = crop_image(im, shift, shift, w, h);
  244. float *pred = calloc(classes, sizeof(float));
  245. for(j = 0; j < 10; ++j){
  246. float *p = network_predict(net, images[j].data);
  247. if(net->hierarchy) hierarchy_predictions(p, net->outputs, net->hierarchy, 1, 1);
  248. axpy_cpu(classes, 1, p, 1, pred, 1);
  249. free_image(images[j]);
  250. }
  251. free_image(im);
  252. top_k(pred, classes, topk, indexes);
  253. free(pred);
  254. if(indexes[0] == class) avg_acc += 1;
  255. for(j = 0; j < topk; ++j){
  256. if(indexes[j] == class) avg_topk += 1;
  257. }
  258. printf("%d: top 1: %f, top %d: %f\n", i, avg_acc/(i+1), topk, avg_topk/(i+1));
  259. }
  260. }
  261. void validate_classifier_full(char *datacfg, char *filename, char *weightfile)
  262. {
  263. int i, j;
  264. network *net = load_network(filename, weightfile, 0);
  265. set_batch_network(net, 1);
  266. srand(time(0));
  267. list *options = read_data_cfg(datacfg);
  268. char *label_list = option_find_str(options, "labels", "data/labels.list");
  269. char *valid_list = option_find_str(options, "valid", "data/train.list");
  270. int classes = option_find_int(options, "classes", 2);
  271. int topk = option_find_int(options, "top", 1);
  272. char **labels = get_labels(label_list);
  273. list *plist = get_paths(valid_list);
  274. char **paths = (char **)list_to_array(plist);
  275. int m = plist->size;
  276. free_list(plist);
  277. float avg_acc = 0;
  278. float avg_topk = 0;
  279. int *indexes = calloc(topk, sizeof(int));
  280. int size = net->w;
  281. for(i = 0; i < m; ++i){
  282. int class = -1;
  283. char *path = paths[i];
  284. for(j = 0; j < classes; ++j){
  285. if(strstr(path, labels[j])){
  286. class = j;
  287. break;
  288. }
  289. }
  290. image im = load_image_color(paths[i], 0, 0);
  291. image resized = resize_min(im, size);
  292. resize_network(net, resized.w, resized.h);
  293. //show_image(im, "orig");
  294. //show_image(crop, "cropped");
  295. //cvWaitKey(0);
  296. float *pred = network_predict(net, resized.data);
  297. if(net->hierarchy) hierarchy_predictions(pred, net->outputs, net->hierarchy, 1, 1);
  298. free_image(im);
  299. free_image(resized);
  300. top_k(pred, classes, topk, indexes);
  301. if(indexes[0] == class) avg_acc += 1;
  302. for(j = 0; j < topk; ++j){
  303. if(indexes[j] == class) avg_topk += 1;
  304. }
  305. printf("%d: top 1: %f, top %d: %f\n", i, avg_acc/(i+1), topk, avg_topk/(i+1));
  306. }
  307. }
  308. void validate_classifier_single(char *datacfg, char *filename, char *weightfile)
  309. {
  310. int i, j;
  311. network *net = load_network(filename, weightfile, 0);
  312. set_batch_network(net, 1);
  313. srand(time(0));
  314. list *options = read_data_cfg(datacfg);
  315. char *label_list = option_find_str(options, "labels", "data/labels.list");
  316. char *leaf_list = option_find_str(options, "leaves", 0);
  317. if(leaf_list) change_leaves(net->hierarchy, leaf_list);
  318. char *valid_list = option_find_str(options, "valid", "data/train.list");
  319. int classes = option_find_int(options, "classes", 2);
  320. int topk = option_find_int(options, "top", 1);
  321. char **labels = get_labels(label_list);
  322. list *plist = get_paths(valid_list);
  323. char **paths = (char **)list_to_array(plist);
  324. int m = plist->size;
  325. free_list(plist);
  326. float avg_acc = 0;
  327. float avg_topk = 0;
  328. int *indexes = calloc(topk, sizeof(int));
  329. for(i = 0; i < m; ++i){
  330. int class = -1;
  331. char *path = paths[i];
  332. for(j = 0; j < classes; ++j){
  333. if(strstr(path, labels[j])){
  334. class = j;
  335. break;
  336. }
  337. }
  338. image im = load_image_color(paths[i], 0, 0);
  339. image crop = center_crop_image(im, net->w, net->h);
  340. //grayscale_image_3c(crop);
  341. //show_image(im, "orig");
  342. //show_image(crop, "cropped");
  343. //cvWaitKey(0);
  344. float *pred = network_predict(net, crop.data);
  345. if(net->hierarchy) hierarchy_predictions(pred, net->outputs, net->hierarchy, 1, 1);
  346. free_image(im);
  347. free_image(crop);
  348. top_k(pred, classes, topk, indexes);
  349. if(indexes[0] == class) avg_acc += 1;
  350. for(j = 0; j < topk; ++j){
  351. if(indexes[j] == class) avg_topk += 1;
  352. }
  353. printf("%s, %d, %f, %f, \n", paths[i], class, pred[0], pred[1]);
  354. printf("%d: top 1: %f, top %d: %f\n", i, avg_acc/(i+1), topk, avg_topk/(i+1));
  355. }
  356. }
  357. void validate_classifier_multi(char *datacfg, char *cfg, char *weights)
  358. {
  359. int i, j;
  360. network *net = load_network(cfg, weights, 0);
  361. set_batch_network(net, 1);
  362. srand(time(0));
  363. list *options = read_data_cfg(datacfg);
  364. char *label_list = option_find_str(options, "labels", "data/labels.list");
  365. char *valid_list = option_find_str(options, "valid", "data/train.list");
  366. int classes = option_find_int(options, "classes", 2);
  367. int topk = option_find_int(options, "top", 1);
  368. char **labels = get_labels(label_list);
  369. list *plist = get_paths(valid_list);
  370. //int scales[] = {224, 288, 320, 352, 384};
  371. int scales[] = {224, 256, 288, 320};
  372. int nscales = sizeof(scales)/sizeof(scales[0]);
  373. char **paths = (char **)list_to_array(plist);
  374. int m = plist->size;
  375. free_list(plist);
  376. float avg_acc = 0;
  377. float avg_topk = 0;
  378. int *indexes = calloc(topk, sizeof(int));
  379. for(i = 0; i < m; ++i){
  380. int class = -1;
  381. char *path = paths[i];
  382. for(j = 0; j < classes; ++j){
  383. if(strstr(path, labels[j])){
  384. class = j;
  385. break;
  386. }
  387. }
  388. float *pred = calloc(classes, sizeof(float));
  389. image im = load_image_color(paths[i], 0, 0);
  390. for(j = 0; j < nscales; ++j){
  391. image r = resize_max(im, scales[j]);
  392. resize_network(net, r.w, r.h);
  393. float *p = network_predict(net, r.data);
  394. if(net->hierarchy) hierarchy_predictions(p, net->outputs, net->hierarchy, 1 , 1);
  395. axpy_cpu(classes, 1, p, 1, pred, 1);
  396. flip_image(r);
  397. p = network_predict(net, r.data);
  398. axpy_cpu(classes, 1, p, 1, pred, 1);
  399. if(r.data != im.data) free_image(r);
  400. }
  401. free_image(im);
  402. top_k(pred, classes, topk, indexes);
  403. free(pred);
  404. if(indexes[0] == class) avg_acc += 1;
  405. for(j = 0; j < topk; ++j){
  406. if(indexes[j] == class) avg_topk += 1;
  407. }
  408. printf("%d: top 1: %f, top %d: %f\n", i, avg_acc/(i+1), topk, avg_topk/(i+1));
  409. }
  410. }
  411. void try_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int layer_num)
  412. {
  413. network *net = load_network(cfgfile, weightfile, 0);
  414. set_batch_network(net, 1);
  415. srand(2222222);
  416. list *options = read_data_cfg(datacfg);
  417. char *name_list = option_find_str(options, "names", 0);
  418. if(!name_list) name_list = option_find_str(options, "labels", "data/labels.list");
  419. int top = option_find_int(options, "top", 1);
  420. int i = 0;
  421. char **names = get_labels(name_list);
  422. clock_t time;
  423. int *indexes = calloc(top, sizeof(int));
  424. char buff[256];
  425. char *input = buff;
  426. while(1){
  427. if(filename){
  428. strncpy(input, filename, 256);
  429. }else{
  430. printf("Enter Image Path: ");
  431. fflush(stdout);
  432. input = fgets(input, 256, stdin);
  433. if(!input) return;
  434. strtok(input, "\n");
  435. }
  436. image orig = load_image_color(input, 0, 0);
  437. image r = resize_min(orig, 256);
  438. image im = crop_image(r, (r.w - 224 - 1)/2 + 1, (r.h - 224 - 1)/2 + 1, 224, 224);
  439. float mean[] = {0.48263312050943, 0.45230225481413, 0.40099074308742};
  440. float std[] = {0.22590347483426, 0.22120921437787, 0.22103996251583};
  441. float var[3];
  442. var[0] = std[0]*std[0];
  443. var[1] = std[1]*std[1];
  444. var[2] = std[2]*std[2];
  445. normalize_cpu(im.data, mean, var, 1, 3, im.w*im.h);
  446. float *X = im.data;
  447. time=clock();
  448. float *predictions = network_predict(net, X);
  449. layer l = net->layers[layer_num];
  450. for(i = 0; i < l.c; ++i){
  451. if(l.rolling_mean) printf("%f %f %f\n", l.rolling_mean[i], l.rolling_variance[i], l.scales[i]);
  452. }
  453. #ifdef GPU
  454. cuda_pull_array(l.output_gpu, l.output, l.outputs);
  455. #endif
  456. for(i = 0; i < l.outputs; ++i){
  457. printf("%f\n", l.output[i]);
  458. }
  459. /*
  460. printf("\n\nWeights\n");
  461. for(i = 0; i < l.n*l.size*l.size*l.c; ++i){
  462. printf("%f\n", l.filters[i]);
  463. }
  464. printf("\n\nBiases\n");
  465. for(i = 0; i < l.n; ++i){
  466. printf("%f\n", l.biases[i]);
  467. }
  468. */
  469. top_predictions(net, top, indexes);
  470. printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
  471. for(i = 0; i < top; ++i){
  472. int index = indexes[i];
  473. printf("%s: %f\n", names[index], predictions[index]);
  474. }
  475. free_image(im);
  476. if (filename) break;
  477. }
  478. }
  479. void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int top)
  480. {
  481. network *net = load_network(cfgfile, weightfile, 0);
  482. set_batch_network(net, 1);
  483. srand(2222222);
  484. list *options = read_data_cfg(datacfg);
  485. char *name_list = option_find_str(options, "names", 0);
  486. if(!name_list) name_list = option_find_str(options, "labels", "data/labels.list");
  487. if(top == 0) top = option_find_int(options, "top", 1);
  488. int i = 0;
  489. char **names = get_labels(name_list);
  490. clock_t time;
  491. int *indexes = calloc(top, sizeof(int));
  492. char buff[256];
  493. char *input = buff;
  494. while(1){
  495. if(filename){
  496. strncpy(input, filename, 256);
  497. }else{
  498. printf("Enter Image Path: ");
  499. fflush(stdout);
  500. input = fgets(input, 256, stdin);
  501. if(!input) return;
  502. strtok(input, "\n");
  503. }
  504. image im = load_image_color(input, 0, 0);
  505. image r = letterbox_image(im, net->w, net->h);
  506. //image r = resize_min(im, 320);
  507. //printf("%d %d\n", r.w, r.h);
  508. //resize_network(net, r.w, r.h);
  509. //printf("%d %d\n", r.w, r.h);
  510. float *X = r.data;
  511. time=clock();
  512. float *predictions = network_predict(net, X);
  513. if(net->hierarchy) hierarchy_predictions(predictions, net->outputs, net->hierarchy, 1, 1);
  514. top_k(predictions, net->outputs, top, indexes);
  515. fprintf(stderr, "%s: Predicted in %f seconds.\n", input, sec(clock()-time));
  516. for(i = 0; i < top; ++i){
  517. int index = indexes[i];
  518. //if(net->hierarchy) printf("%d, %s: %f, parent: %s \n",index, names[index], predictions[index], (net->hierarchy->parent[index] >= 0) ? names[net->hierarchy->parent[index]] : "Root");
  519. //else printf("%s: %f\n",names[index], predictions[index]);
  520. printf("%5.2f%%: %s\n", predictions[index]*100, names[index]);
  521. }
  522. if(r.data != im.data) free_image(r);
  523. free_image(im);
  524. if (filename) break;
  525. }
  526. }
  527. void label_classifier(char *datacfg, char *filename, char *weightfile)
  528. {
  529. int i;
  530. network *net = load_network(filename, weightfile, 0);
  531. set_batch_network(net, 1);
  532. srand(time(0));
  533. list *options = read_data_cfg(datacfg);
  534. char *label_list = option_find_str(options, "names", "data/labels.list");
  535. char *test_list = option_find_str(options, "test", "data/train.list");
  536. int classes = option_find_int(options, "classes", 2);
  537. char **labels = get_labels(label_list);
  538. list *plist = get_paths(test_list);
  539. char **paths = (char **)list_to_array(plist);
  540. int m = plist->size;
  541. free_list(plist);
  542. for(i = 0; i < m; ++i){
  543. image im = load_image_color(paths[i], 0, 0);
  544. image resized = resize_min(im, net->w);
  545. image crop = crop_image(resized, (resized.w - net->w)/2, (resized.h - net->h)/2, net->w, net->h);
  546. float *pred = network_predict(net, crop.data);
  547. if(resized.data != im.data) free_image(resized);
  548. free_image(im);
  549. free_image(crop);
  550. int ind = max_index(pred, classes);
  551. printf("%s\n", labels[ind]);
  552. }
  553. }
  554. void csv_classifier(char *datacfg, char *cfgfile, char *weightfile)
  555. {
  556. int i,j;
  557. network *net = load_network(cfgfile, weightfile, 0);
  558. srand(time(0));
  559. list *options = read_data_cfg(datacfg);
  560. char *test_list = option_find_str(options, "test", "data/test.list");
  561. int top = option_find_int(options, "top", 1);
  562. list *plist = get_paths(test_list);
  563. char **paths = (char **)list_to_array(plist);
  564. int m = plist->size;
  565. free_list(plist);
  566. int *indexes = calloc(top, sizeof(int));
  567. for(i = 0; i < m; ++i){
  568. double time = what_time_is_it_now();
  569. char *path = paths[i];
  570. image im = load_image_color(path, 0, 0);
  571. image r = letterbox_image(im, net->w, net->h);
  572. float *predictions = network_predict(net, r.data);
  573. if(net->hierarchy) hierarchy_predictions(predictions, net->outputs, net->hierarchy, 1, 1);
  574. top_k(predictions, net->outputs, top, indexes);
  575. printf("%s", path);
  576. for(j = 0; j < top; ++j){
  577. printf("\t%d", indexes[j]);
  578. }
  579. printf("\n");
  580. free_image(im);
  581. free_image(r);
  582. fprintf(stderr, "%lf seconds, %d images, %d total\n", what_time_is_it_now() - time, i+1, m);
  583. }
  584. }
  585. void test_classifier(char *datacfg, char *cfgfile, char *weightfile, int target_layer)
  586. {
  587. int curr = 0;
  588. network *net = load_network(cfgfile, weightfile, 0);
  589. srand(time(0));
  590. list *options = read_data_cfg(datacfg);
  591. char *test_list = option_find_str(options, "test", "data/test.list");
  592. int classes = option_find_int(options, "classes", 2);
  593. list *plist = get_paths(test_list);
  594. char **paths = (char **)list_to_array(plist);
  595. int m = plist->size;
  596. free_list(plist);
  597. clock_t time;
  598. data val, buffer;
  599. load_args args = {0};
  600. args.w = net->w;
  601. args.h = net->h;
  602. args.paths = paths;
  603. args.classes = classes;
  604. args.n = net->batch;
  605. args.m = 0;
  606. args.labels = 0;
  607. args.d = &buffer;
  608. args.type = OLD_CLASSIFICATION_DATA;
  609. pthread_t load_thread = load_data_in_thread(args);
  610. for(curr = net->batch; curr < m; curr += net->batch){
  611. time=clock();
  612. pthread_join(load_thread, 0);
  613. val = buffer;
  614. if(curr < m){
  615. args.paths = paths + curr;
  616. if (curr + net->batch > m) args.n = m - curr;
  617. load_thread = load_data_in_thread(args);
  618. }
  619. fprintf(stderr, "Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time));
  620. time=clock();
  621. matrix pred = network_predict_data(net, val);
  622. int i, j;
  623. if (target_layer >= 0){
  624. //layer l = net->layers[target_layer];
  625. }
  626. for(i = 0; i < pred.rows; ++i){
  627. printf("%s", paths[curr-net->batch+i]);
  628. for(j = 0; j < pred.cols; ++j){
  629. printf("\t%g", pred.vals[i][j]);
  630. }
  631. printf("\n");
  632. }
  633. free_matrix(pred);
  634. fprintf(stderr, "%lf seconds, %d images, %d total\n", sec(clock()-time), val.X.rows, curr);
  635. free_data(val);
  636. }
  637. }
  638. void file_output_classifier(char *datacfg, char *filename, char *weightfile, char *listfile)
  639. {
  640. int i,j;
  641. network *net = load_network(filename, weightfile, 0);
  642. set_batch_network(net, 1);
  643. srand(time(0));
  644. list *options = read_data_cfg(datacfg);
  645. //char *label_list = option_find_str(options, "names", "data/labels.list");
  646. int classes = option_find_int(options, "classes", 2);
  647. list *plist = get_paths(listfile);
  648. char **paths = (char **)list_to_array(plist);
  649. int m = plist->size;
  650. free_list(plist);
  651. for(i = 0; i < m; ++i){
  652. image im = load_image_color(paths[i], 0, 0);
  653. image resized = resize_min(im, net->w);
  654. image crop = crop_image(resized, (resized.w - net->w)/2, (resized.h - net->h)/2, net->w, net->h);
  655. float *pred = network_predict(net, crop.data);
  656. if(net->hierarchy) hierarchy_predictions(pred, net->outputs, net->hierarchy, 0, 1);
  657. if(resized.data != im.data) free_image(resized);
  658. free_image(im);
  659. free_image(crop);
  660. printf("%s", paths[i]);
  661. for(j = 0; j < classes; ++j){
  662. printf("\t%g", pred[j]);
  663. }
  664. printf("\n");
  665. }
  666. }
  /*
   * Live demo: classify frames from a video stream, keep an exponentially
   * smoothed "threat" score derived from the first three prediction values,
   * draw a gauge for it onto the frame and show the frame in a window.
   * The whole body is compiled only when OPENCV is defined.
   *
   * datacfg    - data config; reads "top" and "names".
   * cfgfile    - network config passed to load_network().
   * weightfile - weights passed to load_network().
   * cam_index  - forwarded to open_video_stream().
   * filename   - forwarded to open_video_stream().
   *
   * The loop ends when the stream yields a frame with no data.
   */
  void threat_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_index, const char *filename)
  {
  #ifdef OPENCV
      float threat = 0;
      float roll = .2;   /* weight of the newest frame in the running score */

      printf("Classifier Demo\n");
      network *net = load_network(cfgfile, weightfile, 0);
      set_batch_network(net, 1);
      list *options = read_data_cfg(datacfg);

      srand(2222222);
      void * cap = open_video_stream(filename, cam_index, 0,0,0);

      int top = option_find_int(options, "top", 1);

      char *name_list = option_find_str(options, "names", 0);
      char **names = get_labels(name_list);

      int *indexes = calloc(top, sizeof(int));

      if(!cap) error("Couldn't connect to webcam.\n");
      float fps = 0;
      int i;

      int count = 0;

      while(1){
          ++count;
          struct timeval tval_before, tval_after, tval_result;
          gettimeofday(&tval_before, NULL);

          image in = get_image_from_stream(cap);
          if(!in.data) break;
          image in_s = resize_image(in, net->w, net->h);

          /* Gauge geometry, all relative to the frame size. */
          image out = in;
          int x1 = out.w / 20;
          int y1 = out.h / 20;
          int x2 = 2*x1;
          int y2 = out.h - out.h/20;

          int border = .01*out.h;
          int h = y2 - y1 - 2*border;
          int w = x2 - x1 - 2*border;

          float *predictions = network_predict(net, in_s.data);
          float curr_threat = 0;
          if(1){
              curr_threat = predictions[0] * 0 +
                  predictions[1] * .6 +
                  predictions[2];
          } else {
              curr_threat = predictions[218] +
                  predictions[539] +
                  predictions[540] +
                  predictions[368] +
                  predictions[369] +
                  predictions[370];
          }
          threat = roll * curr_threat + (1-roll) * threat;

          /* Upper marker; its box is filled red when threat > .97. */
          draw_box_width(out, x2 + border, y1 + .02*h, x2 + .5 * w, y1 + .02*h + border, border, 0,0,0);
          if(threat > .97) {
              draw_box_width(out, x2 + .5 * w + border,
                      y1 + .02*h - 2*border,
                      x2 + .5 * w + 6*border,
                      y1 + .02*h + 3*border, 3*border, 1,0,0);
          }
          draw_box_width(out, x2 + .5 * w + border,
                  y1 + .02*h - 2*border,
                  x2 + .5 * w + 6*border,
                  y1 + .02*h + 3*border, .5*border, 0,0,0);

          /* Lower marker; its box is filled yellow when threat > .57. */
          draw_box_width(out, x2 + border, y1 + .42*h, x2 + .5 * w, y1 + .42*h + border, border, 0,0,0);
          if(threat > .57) {
              draw_box_width(out, x2 + .5 * w + border,
                      y1 + .42*h - 2*border,
                      x2 + .5 * w + 6*border,
                      y1 + .42*h + 3*border, 3*border, 1,1,0);
          }
          draw_box_width(out, x2 + .5 * w + border,
                  y1 + .42*h - 2*border,
                  x2 + .5 * w + 6*border,
                  y1 + .42*h + 3*border, .5*border, 0,0,0);

          /* Gauge outline, then one line per filled row; the colour changes
             with height via r and g. */
          draw_box_width(out, x1, y1, x2, y2, border, 0,0,0);
          for(i = 0; i < threat * h ; ++i){
              float ratio = (float) i / h;
              float r = (ratio < .5) ? (2*(ratio)) : 1;
              float g = (ratio < .5) ? 1 : 1 - 2*(ratio - .5);
              draw_box_width(out, x1 + border, y2 - border - i, x2 - border, y2 - border - i, 1, r, g, 0);
          }
          top_predictions(net, top, indexes);
          char buff[256];
          /* only consumed by the disabled save_image() call below */
          snprintf(buff, sizeof(buff), "/home/pjreddie/tmp/threat_%06d", count);
          //save_image(out, buff);

          printf("\033[2J");
          printf("\033[1;1H");
          printf("\nFPS:%.0f\n",fps);

          for(i = 0; i < top; ++i){
              int index = indexes[i];
              printf("%.1f%%: %s\n", predictions[index]*100, names[index]);
          }

          if(1){
              show_image(out, "Threat", 10);
          }
          free_image(in_s);
          free_image(in);

          gettimeofday(&tval_after, NULL);
          timersub(&tval_after, &tval_before, &tval_result);
          /* Use the whole elapsed time (seconds and microseconds) and skip
             the update for a zero-length frame instead of dividing by 0. */
          double elapsed_us = tval_result.tv_sec * 1000000.0 + tval_result.tv_usec;
          float curr = (elapsed_us > 0) ? (float)(1000000.0 / elapsed_us) : fps;
          fps = .9*fps + .1*curr;
      }
      free(indexes);
  #endif
  }
  /*
   * gun_classifier - live demo that flags frames containing "watched" classes.
   *
   * datacfg    : data config file; the "top" and "names" options are read.
   * cfgfile    : network configuration passed to load_network().
   * weightfile : network weights passed to load_network().
   * cam_index  : camera index handed to open_video_stream().
   * filename   : video file handed to open_video_stream().
   *
   * For every frame the screen is cleared and either "Threat Detected!"
   * followed by the names of all watched classes scoring above 0.01, or
   * "Scanning..." is printed; the raw frame is shown in a window titled
   * "Threat Detection".
   *
   * The body is compiled only when OPENCV is defined; otherwise the function
   * does nothing.
   */
  void gun_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_index, const char *filename)
  {
  #ifdef OPENCV
      /* Hard-coded class indices that count as a threat. */
      static const int bad_cats[] = {
          218, 539, 540, 1213, 1501, 1742, 1911, 2415, 4348, 19223,
          368, 369, 370, 1133, 1200, 1306, 2122, 2301, 2537, 2823,
          3179, 3596, 3639, 4489, 5107, 5140, 5289, 6240, 6631, 6762,
          7048, 7171, 7969, 7984, 7989, 8824, 8927, 9915, 10270, 10448,
          13401, 15205, 18358, 18894, 18895, 19249, 19697
      };
      const size_t n_bad = sizeof(bad_cats)/sizeof(bad_cats[0]);
      const double min_score = .01;
      size_t i;

      printf("Classifier Demo\n");
      network *net = load_network(cfgfile, weightfile, 0);
      set_batch_network(net, 1);

      list *options = read_data_cfg(datacfg);
      srand(2222222);

      /* Fail as soon as the stream cannot be opened, before any more setup. */
      void *cap = open_video_stream(filename, cam_index, 0,0,0);
      if(!cap) error("Couldn't connect to webcam.\n");

      int top = option_find_int(options, "top", 1);
      char *name_list = option_find_str(options, "names", 0);
      char **names = get_labels(name_list);
      int *indexes = calloc(top, sizeof(int));

      while(1){
          image in = get_image_from_stream(cap);
          /* NOTE(review): assumes an empty image (data == NULL) marks the end
           * of the stream - confirm against get_image_from_stream(). */
          if(!in.data) break;

          image in_s = resize_image(in, net->w, net->h);
          float *predictions = network_predict(net, in_s.data);
          /* The top-k indices are not consumed below; the call is kept as is. */
          top_predictions(net, top, indexes);

          /* ANSI escapes: clear the screen and move the cursor to row 1, col 1. */
          printf("\033[2J");
          printf("\033[1;1H");

          int threat = 0;
          for(i = 0; i < n_bad; ++i){
              int index = bad_cats[i];
              /* The watched ids are fixed, so skip any that the loaded network
               * does not output instead of reading past the prediction array.
               * NOTE(review): assumes predictions holds net->outputs scores and
               * names has at least as many entries - confirm. */
              if(index >= net->outputs) continue;
              if(predictions[index] > min_score){
                  if(!threat){
                      /* Banner goes out once, ahead of the first match. */
                      printf("Threat Detected!\n");
                      threat = 1;
                  }
                  printf("%s\n", names[index]);
              }
          }
          if(!threat) printf("Scanning...\n");

          show_image(in, "Threat Detection", 10);
          free_image(in_s);
          free_image(in);
      }

      free(indexes);
  #endif
  }
  /*
   * demo_classifier - live classification demo that overlays the top
   * predictions on each frame of a video stream.
   *
   * datacfg    : data config file; the "top", "labels" and "names" options
   *              are read ("labels" is the fallback value for "names").
   * cfgfile    : network configuration; its base name (via basecfg()) is used
   *              as the window title.
   * weightfile : network weights passed to load_network().
   * cam_index  : camera index handed to open_video_stream().
   * filename   : video file handed to open_video_stream().
   *
   * For every frame the screen is cleared, a smoothed FPS estimate and the
   * top predictions are printed, and the same prediction lines are drawn onto
   * the frame before it is shown.
   *
   * The body is compiled only when OPENCV is defined; otherwise the function
   * does nothing.
   */
  void demo_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_index, const char *filename)
  {
  #ifdef OPENCV
      char *base = basecfg(cfgfile);
      image **alphabet = load_alphabet();
      printf("Classifier Demo\n");

      network *net = load_network(cfgfile, weightfile, 0);
      set_batch_network(net, 1);

      list *options = read_data_cfg(datacfg);
      srand(2222222);

      /* Requested capture size. */
      int w = 1280;
      int h = 720;
      void *cap = open_video_stream(filename, cam_index, w, h, 0);
      /* Fail as soon as the stream cannot be opened, before any more setup. */
      if(!cap) error("Couldn't connect to webcam.\n");

      int top = option_find_int(options, "top", 1);
      char *label_list = option_find_str(options, "labels", 0);
      char *name_list = option_find_str(options, "names", label_list);
      char **names = get_labels(name_list);
      int *indexes = calloc(top, sizeof(int));

      float fps = 0;
      int i;

      while(1){
          struct timeval tval_before, tval_after;
          gettimeofday(&tval_before, NULL);

          image in = get_image_from_stream(cap);
          /* NOTE(review): assumes an empty image (data == NULL) marks the end
           * of the stream - confirm against get_image_from_stream(). */
          if(!in.data) break;

          image in_s = letterbox_image(in, net->w, net->h);
          float *predictions = network_predict(net, in_s.data);
          if(net->hierarchy) hierarchy_predictions(predictions, net->outputs, net->hierarchy, 1, 1);
          top_predictions(net, top, indexes);

          /* ANSI escapes: clear the screen and move the cursor to row 1, col 1. */
          printf("\033[2J");
          printf("\033[1;1H");
          printf("\nFPS:%.0f\n",fps);

          int lh = in.h*.03;      /* label height: 3% of the frame height */
          int toph = 3*lh;        /* vertical position of the first label */
          float rgb[3] = {1,1,1};

          for(i = 0; i < top; ++i){
              /* NOTE(review): prints the label offset; looks like leftover
               * debug output - confirm before removing. */
              printf("%d\n", toph);

              int index = indexes[i];
              printf("%.1f%%: %s\n", predictions[index]*100, names[index]);

              /* Bounded write: class names are external data of unknown length. */
              char buff[1024];
              snprintf(buff, sizeof(buff), "%3.1f%%: %s\n", predictions[index]*100, names[index]);

              image label = get_label(alphabet, buff, lh);
              draw_label(in, toph, lh, label, rgb);
              toph += 2*lh;
              free_image(label);
          }

          show_image(in, base, 10);
          free_image(in_s);
          free_image(in);

          gettimeofday(&tval_after, NULL);
          /* Use seconds AND microseconds: tv_usec alone is only the sub-second
           * remainder, so it misreports frames of one second or more and is
           * zero on an exact second. */
          double elapsed_us = (double)(tval_after.tv_sec - tval_before.tv_sec) * 1000000.0
                            + (double)(tval_after.tv_usec - tval_before.tv_usec);
          if(elapsed_us > 0){
              float curr = 1000000.f/elapsed_us;
              /* Exponential moving average keeps the displayed value steady. */
              fps = .9*fps + .1*curr;
          }
      }

      free(indexes);
  #endif
  }
  879. void run_classifier(int argc, char **argv)
  880. {
  881. if(argc < 4){
  882. fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
  883. return;
  884. }
  885. char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
  886. int ngpus;
  887. int *gpus = read_intlist(gpu_list, &ngpus, gpu_index);
  888. int cam_index = find_int_arg(argc, argv, "-c", 0);
  889. int top = find_int_arg(argc, argv, "-t", 0);
  890. int clear = find_arg(argc, argv, "-clear");
  891. char *data = argv[3];
  892. char *cfg = argv[4];
  893. char *weights = (argc > 5) ? argv[5] : 0;
  894. char *filename = (argc > 6) ? argv[6]: 0;
  895. char *layer_s = (argc > 7) ? argv[7]: 0;
  896. int layer = layer_s ? atoi(layer_s) : -1;
  897. if(0==strcmp(argv[2], "predict")) predict_classifier(data, cfg, weights, filename, top);
  898. else if(0==strcmp(argv[2], "fout")) file_output_classifier(data, cfg, weights, filename);
  899. else if(0==strcmp(argv[2], "try")) try_classifier(data, cfg, weights, filename, atoi(layer_s));
  900. else if(0==strcmp(argv[2], "train")) train_classifier(data, cfg, weights, gpus, ngpus, clear);
  901. else if(0==strcmp(argv[2], "demo")) demo_classifier(data, cfg, weights, cam_index, filename);
  902. else if(0==strcmp(argv[2], "gun")) gun_classifier(data, cfg, weights, cam_index, filename);
  903. else if(0==strcmp(argv[2], "threat")) threat_classifier(data, cfg, weights, cam_index, filename);
  904. else if(0==strcmp(argv[2], "test")) test_classifier(data, cfg, weights, layer);
  905. else if(0==strcmp(argv[2], "csv")) csv_classifier(data, cfg, weights);
  906. else if(0==strcmp(argv[2], "label")) label_classifier(data, cfg, weights);
  907. else if(0==strcmp(argv[2], "valid")) validate_classifier_single(data, cfg, weights);
  908. else if(0==strcmp(argv[2], "validmulti")) validate_classifier_multi(data, cfg, weights);
  909. else if(0==strcmp(argv[2], "valid10")) validate_classifier_10(data, cfg, weights);
  910. else if(0==strcmp(argv[2], "validcrop")) validate_classifier_crop(data, cfg, weights);
  911. else if(0==strcmp(argv[2], "validfull")) validate_classifier_full(data, cfg, weights);
  912. }