123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353 |
- #include "darknet.h"
- void fix_data_captcha(data d, int mask)
- {
- matrix labels = d.y;
- int i, j;
- for(i = 0; i < d.y.rows; ++i){
- for(j = 0; j < d.y.cols; j += 2){
- if (mask){
- if(!labels.vals[i][j]){
- labels.vals[i][j] = SECRET_NUM;
- labels.vals[i][j+1] = SECRET_NUM;
- }else if(labels.vals[i][j+1]){
- labels.vals[i][j] = 0;
- }
- } else{
- if (labels.vals[i][j]) {
- labels.vals[i][j+1] = 0;
- } else {
- labels.vals[i][j+1] = 1;
- }
- }
- }
- }
- }
- void train_captcha(char *cfgfile, char *weightfile)
- {
- srand(time(0));
- float avg_loss = -1;
- char *base = basecfg(cfgfile);
- printf("%s\n", base);
- network *net = load_network(cfgfile, weightfile, 0);
- printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
- int imgs = 1024;
- int i = *net->seen/imgs;
- int solved = 1;
- list *plist;
- char **labels = get_labels("/data/captcha/reimgs.labels.list");
- if (solved){
- plist = get_paths("/data/captcha/reimgs.solved.list");
- }else{
- plist = get_paths("/data/captcha/reimgs.raw.list");
- }
- char **paths = (char **)list_to_array(plist);
- printf("%d\n", plist->size);
- clock_t time;
- pthread_t load_thread;
- data train;
- data buffer;
- load_args args = {0};
- args.w = net->w;
- args.h = net->h;
- args.paths = paths;
- args.classes = 26;
- args.n = imgs;
- args.m = plist->size;
- args.labels = labels;
- args.d = &buffer;
- args.type = CLASSIFICATION_DATA;
- load_thread = load_data_in_thread(args);
- while(1){
- ++i;
- time=clock();
- pthread_join(load_thread, 0);
- train = buffer;
- fix_data_captcha(train, solved);
- /*
- image im = float_to_image(256, 256, 3, train.X.vals[114]);
- show_image(im, "training");
- cvWaitKey(0);
- */
- load_thread = load_data_in_thread(args);
- printf("Loaded: %lf seconds\n", sec(clock()-time));
- time=clock();
- float loss = train_network(net, train);
- if(avg_loss == -1) avg_loss = loss;
- avg_loss = avg_loss*.9 + loss*.1;
- printf("%d: %f, %f avg, %lf seconds, %ld images\n", i, loss, avg_loss, sec(clock()-time), *net->seen);
- free_data(train);
- if(i%100==0){
- char buff[256];
- sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
- save_weights(net, buff);
- }
- }
- }
- void test_captcha(char *cfgfile, char *weightfile, char *filename)
- {
- network *net = load_network(cfgfile, weightfile, 0);
- set_batch_network(net, 1);
- srand(2222222);
- int i = 0;
- char **names = get_labels("/data/captcha/reimgs.labels.list");
- char buff[256];
- char *input = buff;
- int indexes[26];
- while(1){
- if(filename){
- strncpy(input, filename, 256);
- }else{
- //printf("Enter Image Path: ");
- //fflush(stdout);
- input = fgets(input, 256, stdin);
- if(!input) return;
- strtok(input, "\n");
- }
- image im = load_image_color(input, net->w, net->h);
- float *X = im.data;
- float *predictions = network_predict(net, X);
- top_predictions(net, 26, indexes);
- //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
- for(i = 0; i < 26; ++i){
- int index = indexes[i];
- if(i != 0) printf(", ");
- printf("%s %f", names[index], predictions[index]);
- }
- printf("\n");
- fflush(stdout);
- free_image(im);
- if (filename) break;
- }
- }
- void valid_captcha(char *cfgfile, char *weightfile, char *filename)
- {
- char **labels = get_labels("/data/captcha/reimgs.labels.list");
- network *net = load_network(cfgfile, weightfile, 0);
- list *plist = get_paths("/data/captcha/reimgs.fg.list");
- char **paths = (char **)list_to_array(plist);
- int N = plist->size;
- int outputs = net->outputs;
- set_batch_network(net, 1);
- srand(2222222);
- int i, j;
- for(i = 0; i < N; ++i){
- if (i%100 == 0) fprintf(stderr, "%d\n", i);
- image im = load_image_color(paths[i], net->w, net->h);
- float *X = im.data;
- float *predictions = network_predict(net, X);
- //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
- int truth = -1;
- for(j = 0; j < 13; ++j){
- if (strstr(paths[i], labels[j])) truth = j;
- }
- if (truth == -1){
- fprintf(stderr, "bad: %s\n", paths[i]);
- return;
- }
- printf("%d, ", truth);
- for(j = 0; j < outputs; ++j){
- if (j != 0) printf(", ");
- printf("%f", predictions[j]);
- }
- printf("\n");
- fflush(stdout);
- free_image(im);
- if (filename) break;
- }
- }
- /*
- void train_captcha(char *cfgfile, char *weightfile)
- {
- float avg_loss = -1;
- srand(time(0));
- char *base = basecfg(cfgfile);
- printf("%s\n", base);
- network net = parse_network_cfg(cfgfile);
- if(weightfile){
- load_weights(&net, weightfile);
- }
- printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
- int imgs = 1024;
- int i = net->seen/imgs;
- list *plist = get_paths("/data/captcha/train.auto5");
- char **paths = (char **)list_to_array(plist);
- printf("%d\n", plist->size);
- clock_t time;
- while(1){
- ++i;
- time=clock();
- data train = load_data_captcha(paths, imgs, plist->size, 10, 200, 60);
- translate_data_rows(train, -128);
- scale_data_rows(train, 1./128);
- printf("Loaded: %lf seconds\n", sec(clock()-time));
- time=clock();
- float loss = train_network(net, train);
- net->seen += imgs;
- if(avg_loss == -1) avg_loss = loss;
- avg_loss = avg_loss*.9 + loss*.1;
- printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net->seen);
- free_data(train);
- if(i%10==0){
- char buff[256];
- sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
- save_weights(net, buff);
- }
- }
- }
- void decode_captcha(char *cfgfile, char *weightfile)
- {
- setbuf(stdout, NULL);
- srand(time(0));
- network net = parse_network_cfg(cfgfile);
- set_batch_network(&net, 1);
- if(weightfile){
- load_weights(&net, weightfile);
- }
- char filename[256];
- while(1){
- printf("Enter filename: ");
- fgets(filename, 256, stdin);
- strtok(filename, "\n");
- image im = load_image_color(filename, 300, 57);
- scale_image(im, 1./255.);
- float *X = im.data;
- float *predictions = network_predict(net, X);
- image out = float_to_image(300, 57, 1, predictions);
- show_image(out, "decoded");
- #ifdef OPENCV
- cvWaitKey(0);
- #endif
- free_image(im);
- }
- }
- void encode_captcha(char *cfgfile, char *weightfile)
- {
- float avg_loss = -1;
- srand(time(0));
- char *base = basecfg(cfgfile);
- printf("%s\n", base);
- network net = parse_network_cfg(cfgfile);
- if(weightfile){
- load_weights(&net, weightfile);
- }
- printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
- int imgs = 1024;
- int i = net->seen/imgs;
- list *plist = get_paths("/data/captcha/encode.list");
- char **paths = (char **)list_to_array(plist);
- printf("%d\n", plist->size);
- clock_t time;
- while(1){
- ++i;
- time=clock();
- data train = load_data_captcha_encode(paths, imgs, plist->size, 300, 57);
- scale_data_rows(train, 1./255);
- printf("Loaded: %lf seconds\n", sec(clock()-time));
- time=clock();
- float loss = train_network(net, train);
- net->seen += imgs;
- if(avg_loss == -1) avg_loss = loss;
- avg_loss = avg_loss*.9 + loss*.1;
- printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net->seen);
- free_matrix(train.X);
- if(i%100==0){
- char buff[256];
- sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
- save_weights(net, buff);
- }
- }
- }
- void validate_captcha(char *cfgfile, char *weightfile)
- {
- srand(time(0));
- char *base = basecfg(cfgfile);
- printf("%s\n", base);
- network net = parse_network_cfg(cfgfile);
- if(weightfile){
- load_weights(&net, weightfile);
- }
- int numchars = 37;
- list *plist = get_paths("/data/captcha/solved.hard");
- char **paths = (char **)list_to_array(plist);
- int imgs = plist->size;
- data valid = load_data_captcha(paths, imgs, 0, 10, 200, 60);
- translate_data_rows(valid, -128);
- scale_data_rows(valid, 1./128);
- matrix pred = network_predict_data(net, valid);
- int i, k;
- int correct = 0;
- int total = 0;
- int accuracy = 0;
- for(i = 0; i < imgs; ++i){
- int allcorrect = 1;
- for(k = 0; k < 10; ++k){
- char truth = int_to_alphanum(max_index(valid.y.vals[i]+k*numchars, numchars));
- char prediction = int_to_alphanum(max_index(pred.vals[i]+k*numchars, numchars));
- if (truth != prediction) allcorrect=0;
- if (truth != '.' && truth == prediction) ++correct;
- if (truth != '.' || truth != prediction) ++total;
- }
- accuracy += allcorrect;
- }
- printf("Word Accuracy: %f, Char Accuracy %f\n", (float)accuracy/imgs, (float)correct/total);
- free_data(valid);
- }
- void test_captcha(char *cfgfile, char *weightfile)
- {
- setbuf(stdout, NULL);
- srand(time(0));
- //char *base = basecfg(cfgfile);
- //printf("%s\n", base);
- network net = parse_network_cfg(cfgfile);
- set_batch_network(&net, 1);
- if(weightfile){
- load_weights(&net, weightfile);
- }
- char filename[256];
- while(1){
- //printf("Enter filename: ");
- fgets(filename, 256, stdin);
- strtok(filename, "\n");
- image im = load_image_color(filename, 200, 60);
- translate_image(im, -128);
- scale_image(im, 1/128.);
- float *X = im.data;
- float *predictions = network_predict(net, X);
- print_letters(predictions, 10);
- free_image(im);
- }
- }
- */
/*
 * Command-line entry point for the captcha example.
 *
 * argv[2] - mode: "train", "test" or "valid"
 * argv[3] - network cfg file
 * argv[4] - weights file (optional)
 * argv[5] - image filename for test/valid (optional)
 *
 * Prints usage and returns when fewer than 4 arguments are given; reports an
 * unrecognized mode on stderr. The legacy encode/decode/validate modes are
 * not dispatched.
 */
void run_captcha(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)] [filename (optional)]\n", argv[0], argv[1]);
        return;
    }
    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    char *filename = (argc > 5) ? argv[5] : 0;
    if(0==strcmp(argv[2], "train")) train_captcha(cfg, weights);
    else if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights, filename);
    else if(0==strcmp(argv[2], "valid")) valid_captcha(cfg, weights, filename);
    else fprintf(stderr, "unknown captcha mode: %s\n", argv[2]);
}
|