captcha.c 11 KB


  1. #include "darknet.h"
  2. void fix_data_captcha(data d, int mask)
  3. {
  4. matrix labels = d.y;
  5. int i, j;
  6. for(i = 0; i < d.y.rows; ++i){
  7. for(j = 0; j < d.y.cols; j += 2){
  8. if (mask){
  9. if(!labels.vals[i][j]){
  10. labels.vals[i][j] = SECRET_NUM;
  11. labels.vals[i][j+1] = SECRET_NUM;
  12. }else if(labels.vals[i][j+1]){
  13. labels.vals[i][j] = 0;
  14. }
  15. } else{
  16. if (labels.vals[i][j]) {
  17. labels.vals[i][j+1] = 0;
  18. } else {
  19. labels.vals[i][j+1] = 1;
  20. }
  21. }
  22. }
  23. }
  24. }
  25. void train_captcha(char *cfgfile, char *weightfile)
  26. {
  27. srand(time(0));
  28. float avg_loss = -1;
  29. char *base = basecfg(cfgfile);
  30. printf("%s\n", base);
  31. network *net = load_network(cfgfile, weightfile, 0);
  32. printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
  33. int imgs = 1024;
  34. int i = *net->seen/imgs;
  35. int solved = 1;
  36. list *plist;
  37. char **labels = get_labels("/data/captcha/reimgs.labels.list");
  38. if (solved){
  39. plist = get_paths("/data/captcha/reimgs.solved.list");
  40. }else{
  41. plist = get_paths("/data/captcha/reimgs.raw.list");
  42. }
  43. char **paths = (char **)list_to_array(plist);
  44. printf("%d\n", plist->size);
  45. clock_t time;
  46. pthread_t load_thread;
  47. data train;
  48. data buffer;
  49. load_args args = {0};
  50. args.w = net->w;
  51. args.h = net->h;
  52. args.paths = paths;
  53. args.classes = 26;
  54. args.n = imgs;
  55. args.m = plist->size;
  56. args.labels = labels;
  57. args.d = &buffer;
  58. args.type = CLASSIFICATION_DATA;
  59. load_thread = load_data_in_thread(args);
  60. while(1){
  61. ++i;
  62. time=clock();
  63. pthread_join(load_thread, 0);
  64. train = buffer;
  65. fix_data_captcha(train, solved);
  66. /*
  67. image im = float_to_image(256, 256, 3, train.X.vals[114]);
  68. show_image(im, "training");
  69. cvWaitKey(0);
  70. */
  71. load_thread = load_data_in_thread(args);
  72. printf("Loaded: %lf seconds\n", sec(clock()-time));
  73. time=clock();
  74. float loss = train_network(net, train);
  75. if(avg_loss == -1) avg_loss = loss;
  76. avg_loss = avg_loss*.9 + loss*.1;
  77. printf("%d: %f, %f avg, %lf seconds, %ld images\n", i, loss, avg_loss, sec(clock()-time), *net->seen);
  78. free_data(train);
  79. if(i%100==0){
  80. char buff[256];
  81. sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
  82. save_weights(net, buff);
  83. }
  84. }
  85. }
  86. void test_captcha(char *cfgfile, char *weightfile, char *filename)
  87. {
  88. network *net = load_network(cfgfile, weightfile, 0);
  89. set_batch_network(net, 1);
  90. srand(2222222);
  91. int i = 0;
  92. char **names = get_labels("/data/captcha/reimgs.labels.list");
  93. char buff[256];
  94. char *input = buff;
  95. int indexes[26];
  96. while(1){
  97. if(filename){
  98. strncpy(input, filename, 256);
  99. }else{
  100. //printf("Enter Image Path: ");
  101. //fflush(stdout);
  102. input = fgets(input, 256, stdin);
  103. if(!input) return;
  104. strtok(input, "\n");
  105. }
  106. image im = load_image_color(input, net->w, net->h);
  107. float *X = im.data;
  108. float *predictions = network_predict(net, X);
  109. top_predictions(net, 26, indexes);
  110. //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
  111. for(i = 0; i < 26; ++i){
  112. int index = indexes[i];
  113. if(i != 0) printf(", ");
  114. printf("%s %f", names[index], predictions[index]);
  115. }
  116. printf("\n");
  117. fflush(stdout);
  118. free_image(im);
  119. if (filename) break;
  120. }
  121. }
  122. void valid_captcha(char *cfgfile, char *weightfile, char *filename)
  123. {
  124. char **labels = get_labels("/data/captcha/reimgs.labels.list");
  125. network *net = load_network(cfgfile, weightfile, 0);
  126. list *plist = get_paths("/data/captcha/reimgs.fg.list");
  127. char **paths = (char **)list_to_array(plist);
  128. int N = plist->size;
  129. int outputs = net->outputs;
  130. set_batch_network(net, 1);
  131. srand(2222222);
  132. int i, j;
  133. for(i = 0; i < N; ++i){
  134. if (i%100 == 0) fprintf(stderr, "%d\n", i);
  135. image im = load_image_color(paths[i], net->w, net->h);
  136. float *X = im.data;
  137. float *predictions = network_predict(net, X);
  138. //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
  139. int truth = -1;
  140. for(j = 0; j < 13; ++j){
  141. if (strstr(paths[i], labels[j])) truth = j;
  142. }
  143. if (truth == -1){
  144. fprintf(stderr, "bad: %s\n", paths[i]);
  145. return;
  146. }
  147. printf("%d, ", truth);
  148. for(j = 0; j < outputs; ++j){
  149. if (j != 0) printf(", ");
  150. printf("%f", predictions[j]);
  151. }
  152. printf("\n");
  153. fflush(stdout);
  154. free_image(im);
  155. if (filename) break;
  156. }
  157. }
  158. /*
  159. void train_captcha(char *cfgfile, char *weightfile)
  160. {
  161. float avg_loss = -1;
  162. srand(time(0));
  163. char *base = basecfg(cfgfile);
  164. printf("%s\n", base);
  165. network net = parse_network_cfg(cfgfile);
  166. if(weightfile){
  167. load_weights(&net, weightfile);
  168. }
  169. printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
  170. int imgs = 1024;
  171. int i = net->seen/imgs;
  172. list *plist = get_paths("/data/captcha/train.auto5");
  173. char **paths = (char **)list_to_array(plist);
  174. printf("%d\n", plist->size);
  175. clock_t time;
  176. while(1){
  177. ++i;
  178. time=clock();
  179. data train = load_data_captcha(paths, imgs, plist->size, 10, 200, 60);
  180. translate_data_rows(train, -128);
  181. scale_data_rows(train, 1./128);
  182. printf("Loaded: %lf seconds\n", sec(clock()-time));
  183. time=clock();
  184. float loss = train_network(net, train);
  185. net->seen += imgs;
  186. if(avg_loss == -1) avg_loss = loss;
  187. avg_loss = avg_loss*.9 + loss*.1;
  188. printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net->seen);
  189. free_data(train);
  190. if(i%10==0){
  191. char buff[256];
  192. sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
  193. save_weights(net, buff);
  194. }
  195. }
  196. }
  197. void decode_captcha(char *cfgfile, char *weightfile)
  198. {
  199. setbuf(stdout, NULL);
  200. srand(time(0));
  201. network net = parse_network_cfg(cfgfile);
  202. set_batch_network(&net, 1);
  203. if(weightfile){
  204. load_weights(&net, weightfile);
  205. }
  206. char filename[256];
  207. while(1){
  208. printf("Enter filename: ");
  209. fgets(filename, 256, stdin);
  210. strtok(filename, "\n");
  211. image im = load_image_color(filename, 300, 57);
  212. scale_image(im, 1./255.);
  213. float *X = im.data;
  214. float *predictions = network_predict(net, X);
  215. image out = float_to_image(300, 57, 1, predictions);
  216. show_image(out, "decoded");
  217. #ifdef OPENCV
  218. cvWaitKey(0);
  219. #endif
  220. free_image(im);
  221. }
  222. }
  223. void encode_captcha(char *cfgfile, char *weightfile)
  224. {
  225. float avg_loss = -1;
  226. srand(time(0));
  227. char *base = basecfg(cfgfile);
  228. printf("%s\n", base);
  229. network net = parse_network_cfg(cfgfile);
  230. if(weightfile){
  231. load_weights(&net, weightfile);
  232. }
  233. printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
  234. int imgs = 1024;
  235. int i = net->seen/imgs;
  236. list *plist = get_paths("/data/captcha/encode.list");
  237. char **paths = (char **)list_to_array(plist);
  238. printf("%d\n", plist->size);
  239. clock_t time;
  240. while(1){
  241. ++i;
  242. time=clock();
  243. data train = load_data_captcha_encode(paths, imgs, plist->size, 300, 57);
  244. scale_data_rows(train, 1./255);
  245. printf("Loaded: %lf seconds\n", sec(clock()-time));
  246. time=clock();
  247. float loss = train_network(net, train);
  248. net->seen += imgs;
  249. if(avg_loss == -1) avg_loss = loss;
  250. avg_loss = avg_loss*.9 + loss*.1;
  251. printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net->seen);
  252. free_matrix(train.X);
  253. if(i%100==0){
  254. char buff[256];
  255. sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
  256. save_weights(net, buff);
  257. }
  258. }
  259. }
  260. void validate_captcha(char *cfgfile, char *weightfile)
  261. {
  262. srand(time(0));
  263. char *base = basecfg(cfgfile);
  264. printf("%s\n", base);
  265. network net = parse_network_cfg(cfgfile);
  266. if(weightfile){
  267. load_weights(&net, weightfile);
  268. }
  269. int numchars = 37;
  270. list *plist = get_paths("/data/captcha/solved.hard");
  271. char **paths = (char **)list_to_array(plist);
  272. int imgs = plist->size;
  273. data valid = load_data_captcha(paths, imgs, 0, 10, 200, 60);
  274. translate_data_rows(valid, -128);
  275. scale_data_rows(valid, 1./128);
  276. matrix pred = network_predict_data(net, valid);
  277. int i, k;
  278. int correct = 0;
  279. int total = 0;
  280. int accuracy = 0;
  281. for(i = 0; i < imgs; ++i){
  282. int allcorrect = 1;
  283. for(k = 0; k < 10; ++k){
  284. char truth = int_to_alphanum(max_index(valid.y.vals[i]+k*numchars, numchars));
  285. char prediction = int_to_alphanum(max_index(pred.vals[i]+k*numchars, numchars));
  286. if (truth != prediction) allcorrect=0;
  287. if (truth != '.' && truth == prediction) ++correct;
  288. if (truth != '.' || truth != prediction) ++total;
  289. }
  290. accuracy += allcorrect;
  291. }
  292. printf("Word Accuracy: %f, Char Accuracy %f\n", (float)accuracy/imgs, (float)correct/total);
  293. free_data(valid);
  294. }
  295. void test_captcha(char *cfgfile, char *weightfile)
  296. {
  297. setbuf(stdout, NULL);
  298. srand(time(0));
  299. //char *base = basecfg(cfgfile);
  300. //printf("%s\n", base);
  301. network net = parse_network_cfg(cfgfile);
  302. set_batch_network(&net, 1);
  303. if(weightfile){
  304. load_weights(&net, weightfile);
  305. }
  306. char filename[256];
  307. while(1){
  308. //printf("Enter filename: ");
  309. fgets(filename, 256, stdin);
  310. strtok(filename, "\n");
  311. image im = load_image_color(filename, 200, 60);
  312. translate_image(im, -128);
  313. scale_image(im, 1/128.);
  314. float *X = im.data;
  315. float *predictions = network_predict(net, X);
  316. print_letters(predictions, 10);
  317. free_image(im);
  318. }
  319. }
  320. */
  /*
   * run_captcha - command-line entry point for the captcha sub-commands.
   *
   * Expected arguments: argv[2] = mode ("train", "test" or "valid"),
   * argv[3] = cfg file, argv[4] = weights (optional), argv[5] = image
   * file name (optional, forwarded to "test" and "valid").
   *
   * With fewer than four arguments a usage line is printed to stderr.
   * A mode other than the three listed above does nothing.
   */
  void run_captcha(int argc, char **argv)
  {
      if (argc < 4) {
          fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
          return;
      }

      char *mode = argv[2];
      char *cfg = argv[3];

      /* Optional trailing arguments default to a null pointer. */
      char *weights = 0;
      char *filename = 0;
      if (argc > 4) weights = argv[4];
      if (argc > 5) filename = argv[5];

      if (strcmp(mode, "train") == 0) {
          train_captcha(cfg, weights);
          return;
      }
      if (strcmp(mode, "test") == 0) {
          test_captcha(cfg, weights, filename);
          return;
      }
      if (strcmp(mode, "valid") == 0) {
          valid_captcha(cfg, weights, filename);
          return;
      }
  }