rnn.c 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. #include "darknet.h"
  2. #include <math.h>
  3. typedef struct {
  4. float *x;
  5. float *y;
  6. } float_pair;
  7. unsigned char **load_files(char *filename, int *n)
  8. {
  9. list *paths = get_paths(filename);
  10. *n = paths->size;
  11. unsigned char **contents = calloc(*n, sizeof(char *));
  12. int i;
  13. node *x = paths->front;
  14. for(i = 0; i < *n; ++i){
  15. contents[i] = read_file((char *)x->val);
  16. x = x->next;
  17. }
  18. return contents;
  19. }
  20. int *read_tokenized_data(char *filename, size_t *read)
  21. {
  22. size_t size = 512;
  23. size_t count = 0;
  24. FILE *fp = fopen(filename, "r");
  25. int *d = calloc(size, sizeof(int));
  26. int n, one;
  27. one = fscanf(fp, "%d", &n);
  28. while(one == 1){
  29. ++count;
  30. if(count > size){
  31. size = size*2;
  32. d = realloc(d, size*sizeof(int));
  33. }
  34. d[count-1] = n;
  35. one = fscanf(fp, "%d", &n);
  36. }
  37. fclose(fp);
  38. d = realloc(d, count*sizeof(int));
  39. *read = count;
  40. return d;
  41. }
  42. char **read_tokens(char *filename, size_t *read)
  43. {
  44. size_t size = 512;
  45. size_t count = 0;
  46. FILE *fp = fopen(filename, "r");
  47. char **d = calloc(size, sizeof(char *));
  48. char *line;
  49. while((line=fgetl(fp)) != 0){
  50. ++count;
  51. if(count > size){
  52. size = size*2;
  53. d = realloc(d, size*sizeof(char *));
  54. }
  55. if(0==strcmp(line, "<NEWLINE>")) line = "\n";
  56. d[count-1] = line;
  57. }
  58. fclose(fp);
  59. d = realloc(d, count*sizeof(char *));
  60. *read = count;
  61. return d;
  62. }
  63. float_pair get_rnn_token_data(int *tokens, size_t *offsets, int characters, size_t len, int batch, int steps)
  64. {
  65. float *x = calloc(batch * steps * characters, sizeof(float));
  66. float *y = calloc(batch * steps * characters, sizeof(float));
  67. int i,j;
  68. for(i = 0; i < batch; ++i){
  69. for(j = 0; j < steps; ++j){
  70. int curr = tokens[(offsets[i])%len];
  71. int next = tokens[(offsets[i] + 1)%len];
  72. x[(j*batch + i)*characters + curr] = 1;
  73. y[(j*batch + i)*characters + next] = 1;
  74. offsets[i] = (offsets[i] + 1) % len;
  75. if(curr >= characters || curr < 0 || next >= characters || next < 0){
  76. error("Bad char");
  77. }
  78. }
  79. }
  80. float_pair p;
  81. p.x = x;
  82. p.y = y;
  83. return p;
  84. }
  85. float_pair get_seq2seq_data(char **source, char **dest, int n, int characters, size_t len, int batch, int steps)
  86. {
  87. int i,j;
  88. float *x = calloc(batch * steps * characters, sizeof(float));
  89. float *y = calloc(batch * steps * characters, sizeof(float));
  90. for(i = 0; i < batch; ++i){
  91. int index = rand()%n;
  92. //int slen = strlen(source[index]);
  93. //int dlen = strlen(dest[index]);
  94. for(j = 0; j < steps; ++j){
  95. unsigned char curr = source[index][j];
  96. unsigned char next = dest[index][j];
  97. x[(j*batch + i)*characters + curr] = 1;
  98. y[(j*batch + i)*characters + next] = 1;
  99. if(curr > 255 || curr <= 0 || next > 255 || next <= 0){
  100. /*text[(index+j+2)%len] = 0;
  101. printf("%ld %d %d %d %d\n", index, j, len, (int)text[index+j], (int)text[index+j+1]);
  102. printf("%s", text+index);
  103. */
  104. error("Bad char");
  105. }
  106. }
  107. }
  108. float_pair p;
  109. p.x = x;
  110. p.y = y;
  111. return p;
  112. }
  113. float_pair get_rnn_data(unsigned char *text, size_t *offsets, int characters, size_t len, int batch, int steps)
  114. {
  115. float *x = calloc(batch * steps * characters, sizeof(float));
  116. float *y = calloc(batch * steps * characters, sizeof(float));
  117. int i,j;
  118. for(i = 0; i < batch; ++i){
  119. for(j = 0; j < steps; ++j){
  120. unsigned char curr = text[(offsets[i])%len];
  121. unsigned char next = text[(offsets[i] + 1)%len];
  122. x[(j*batch + i)*characters + curr] = 1;
  123. y[(j*batch + i)*characters + next] = 1;
  124. offsets[i] = (offsets[i] + 1) % len;
  125. if(curr > 255 || curr <= 0 || next > 255 || next <= 0){
  126. /*text[(index+j+2)%len] = 0;
  127. printf("%ld %d %d %d %d\n", index, j, len, (int)text[index+j], (int)text[index+j+1]);
  128. printf("%s", text+index);
  129. */
  130. error("Bad char");
  131. }
  132. }
  133. }
  134. float_pair p;
  135. p.x = x;
  136. p.y = y;
  137. return p;
  138. }
  139. void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized)
  140. {
  141. srand(time(0));
  142. unsigned char *text = 0;
  143. int *tokens = 0;
  144. size_t size;
  145. if(tokenized){
  146. tokens = read_tokenized_data(filename, &size);
  147. } else {
  148. text = read_file(filename);
  149. size = strlen((const char*)text);
  150. }
  151. char *backup_directory = "/home/pjreddie/backup/";
  152. char *base = basecfg(cfgfile);
  153. fprintf(stderr, "%s\n", base);
  154. float avg_loss = -1;
  155. network *net = load_network(cfgfile, weightfile, clear);
  156. int inputs = net->inputs;
  157. fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g, Inputs: %d %d %d\n", net->learning_rate, net->momentum, net->decay, inputs, net->batch, net->time_steps);
  158. int batch = net->batch;
  159. int steps = net->time_steps;
  160. if(clear) *net->seen = 0;
  161. int i = (*net->seen)/net->batch;
  162. int streams = batch/steps;
  163. size_t *offsets = calloc(streams, sizeof(size_t));
  164. int j;
  165. for(j = 0; j < streams; ++j){
  166. offsets[j] = rand_size_t()%size;
  167. }
  168. clock_t time;
  169. while(get_current_batch(net) < net->max_batches){
  170. i += 1;
  171. time=clock();
  172. float_pair p;
  173. if(tokenized){
  174. p = get_rnn_token_data(tokens, offsets, inputs, size, streams, steps);
  175. }else{
  176. p = get_rnn_data(text, offsets, inputs, size, streams, steps);
  177. }
  178. copy_cpu(net->inputs*net->batch, p.x, 1, net->input, 1);
  179. copy_cpu(net->truths*net->batch, p.y, 1, net->truth, 1);
  180. float loss = train_network_datum(net) / (batch);
  181. free(p.x);
  182. free(p.y);
  183. if (avg_loss < 0) avg_loss = loss;
  184. avg_loss = avg_loss*.9 + loss*.1;
  185. size_t chars = get_current_batch(net)*batch;
  186. fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds, %f epochs\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), (float) chars/size);
  187. for(j = 0; j < streams; ++j){
  188. //printf("%d\n", j);
  189. if(rand()%64 == 0){
  190. //fprintf(stderr, "Reset\n");
  191. offsets[j] = rand_size_t()%size;
  192. reset_network_state(net, j);
  193. }
  194. }
  195. if(i%10000==0){
  196. char buff[256];
  197. sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
  198. save_weights(net, buff);
  199. }
  200. if(i%100==0){
  201. char buff[256];
  202. sprintf(buff, "%s/%s.backup", backup_directory, base);
  203. save_weights(net, buff);
  204. }
  205. }
  206. char buff[256];
  207. sprintf(buff, "%s/%s_final.weights", backup_directory, base);
  208. save_weights(net, buff);
  209. }
  210. void print_symbol(int n, char **tokens){
  211. if(tokens){
  212. printf("%s ", tokens[n]);
  213. } else {
  214. printf("%c", n);
  215. }
  216. }
  217. void test_char_rnn(char *cfgfile, char *weightfile, int num, char *seed, float temp, int rseed, char *token_file)
  218. {
  219. char **tokens = 0;
  220. if(token_file){
  221. size_t n;
  222. tokens = read_tokens(token_file, &n);
  223. }
  224. srand(rseed);
  225. char *base = basecfg(cfgfile);
  226. fprintf(stderr, "%s\n", base);
  227. network *net = load_network(cfgfile, weightfile, 0);
  228. int inputs = net->inputs;
  229. int i, j;
  230. for(i = 0; i < net->n; ++i) net->layers[i].temperature = temp;
  231. int c = 0;
  232. int len = strlen(seed);
  233. float *input = calloc(inputs, sizeof(float));
  234. /*
  235. fill_cpu(inputs, 0, input, 1);
  236. for(i = 0; i < 10; ++i){
  237. network_predict(net, input);
  238. }
  239. fill_cpu(inputs, 0, input, 1);
  240. */
  241. for(i = 0; i < len-1; ++i){
  242. c = seed[i];
  243. input[c] = 1;
  244. network_predict(net, input);
  245. input[c] = 0;
  246. print_symbol(c, tokens);
  247. }
  248. if(len) c = seed[len-1];
  249. print_symbol(c, tokens);
  250. for(i = 0; i < num; ++i){
  251. input[c] = 1;
  252. float *out = network_predict(net, input);
  253. input[c] = 0;
  254. for(j = 32; j < 127; ++j){
  255. //printf("%d %c %f\n",j, j, out[j]);
  256. }
  257. for(j = 0; j < inputs; ++j){
  258. if (out[j] < .0001) out[j] = 0;
  259. }
  260. c = sample_array(out, inputs);
  261. print_symbol(c, tokens);
  262. }
  263. printf("\n");
  264. }
  265. void test_tactic_rnn_multi(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
  266. {
  267. char **tokens = 0;
  268. if(token_file){
  269. size_t n;
  270. tokens = read_tokens(token_file, &n);
  271. }
  272. srand(rseed);
  273. char *base = basecfg(cfgfile);
  274. fprintf(stderr, "%s\n", base);
  275. network *net = load_network(cfgfile, weightfile, 0);
  276. int inputs = net->inputs;
  277. int i, j;
  278. for(i = 0; i < net->n; ++i) net->layers[i].temperature = temp;
  279. int c = 0;
  280. float *input = calloc(inputs, sizeof(float));
  281. float *out = 0;
  282. while(1){
  283. reset_network_state(net, 0);
  284. while((c = getc(stdin)) != EOF && c != 0){
  285. input[c] = 1;
  286. out = network_predict(net, input);
  287. input[c] = 0;
  288. }
  289. for(i = 0; i < num; ++i){
  290. for(j = 0; j < inputs; ++j){
  291. if (out[j] < .0001) out[j] = 0;
  292. }
  293. int next = sample_array(out, inputs);
  294. if(c == '.' && next == '\n') break;
  295. c = next;
  296. print_symbol(c, tokens);
  297. input[c] = 1;
  298. out = network_predict(net, input);
  299. input[c] = 0;
  300. }
  301. printf("\n");
  302. }
  303. }
  304. void test_tactic_rnn(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
  305. {
  306. char **tokens = 0;
  307. if(token_file){
  308. size_t n;
  309. tokens = read_tokens(token_file, &n);
  310. }
  311. srand(rseed);
  312. char *base = basecfg(cfgfile);
  313. fprintf(stderr, "%s\n", base);
  314. network *net = load_network(cfgfile, weightfile, 0);
  315. int inputs = net->inputs;
  316. int i, j;
  317. for(i = 0; i < net->n; ++i) net->layers[i].temperature = temp;
  318. int c = 0;
  319. float *input = calloc(inputs, sizeof(float));
  320. float *out = 0;
  321. while((c = getc(stdin)) != EOF){
  322. input[c] = 1;
  323. out = network_predict(net, input);
  324. input[c] = 0;
  325. }
  326. for(i = 0; i < num; ++i){
  327. for(j = 0; j < inputs; ++j){
  328. if (out[j] < .0001) out[j] = 0;
  329. }
  330. int next = sample_array(out, inputs);
  331. if(c == '.' && next == '\n') break;
  332. c = next;
  333. print_symbol(c, tokens);
  334. input[c] = 1;
  335. out = network_predict(net, input);
  336. input[c] = 0;
  337. }
  338. printf("\n");
  339. }
  340. void valid_tactic_rnn(char *cfgfile, char *weightfile, char *seed)
  341. {
  342. char *base = basecfg(cfgfile);
  343. fprintf(stderr, "%s\n", base);
  344. network *net = load_network(cfgfile, weightfile, 0);
  345. int inputs = net->inputs;
  346. int count = 0;
  347. int words = 1;
  348. int c;
  349. int len = strlen(seed);
  350. float *input = calloc(inputs, sizeof(float));
  351. int i;
  352. for(i = 0; i < len; ++i){
  353. c = seed[i];
  354. input[(int)c] = 1;
  355. network_predict(net, input);
  356. input[(int)c] = 0;
  357. }
  358. float sum = 0;
  359. c = getc(stdin);
  360. float log2 = log(2);
  361. int in = 0;
  362. while(c != EOF){
  363. int next = getc(stdin);
  364. if(next == EOF) break;
  365. if(next < 0 || next >= 255) error("Out of range character");
  366. input[c] = 1;
  367. float *out = network_predict(net, input);
  368. input[c] = 0;
  369. if(c == '.' && next == '\n') in = 0;
  370. if(!in) {
  371. if(c == '>' && next == '>'){
  372. in = 1;
  373. ++words;
  374. }
  375. c = next;
  376. continue;
  377. }
  378. ++count;
  379. sum += log(out[next])/log2;
  380. c = next;
  381. printf("%d %d Perplexity: %4.4f Word Perplexity: %4.4f\n", count, words, pow(2, -sum/count), pow(2, -sum/words));
  382. }
  383. }
  384. void valid_char_rnn(char *cfgfile, char *weightfile, char *seed)
  385. {
  386. char *base = basecfg(cfgfile);
  387. fprintf(stderr, "%s\n", base);
  388. network *net = load_network(cfgfile, weightfile, 0);
  389. int inputs = net->inputs;
  390. int count = 0;
  391. int words = 1;
  392. int c;
  393. int len = strlen(seed);
  394. float *input = calloc(inputs, sizeof(float));
  395. int i;
  396. for(i = 0; i < len; ++i){
  397. c = seed[i];
  398. input[(int)c] = 1;
  399. network_predict(net, input);
  400. input[(int)c] = 0;
  401. }
  402. float sum = 0;
  403. c = getc(stdin);
  404. float log2 = log(2);
  405. while(c != EOF){
  406. int next = getc(stdin);
  407. if(next == EOF) break;
  408. if(next < 0 || next >= 255) error("Out of range character");
  409. ++count;
  410. if(next == ' ' || next == '\n' || next == '\t') ++words;
  411. input[c] = 1;
  412. float *out = network_predict(net, input);
  413. input[c] = 0;
  414. sum += log(out[next])/log2;
  415. c = next;
  416. printf("%d BPC: %4.4f Perplexity: %4.4f Word Perplexity: %4.4f\n", count, -sum/count, pow(2, -sum/count), pow(2, -sum/words));
  417. }
  418. }
  419. void vec_char_rnn(char *cfgfile, char *weightfile, char *seed)
  420. {
  421. char *base = basecfg(cfgfile);
  422. fprintf(stderr, "%s\n", base);
  423. network *net = load_network(cfgfile, weightfile, 0);
  424. int inputs = net->inputs;
  425. int c;
  426. int seed_len = strlen(seed);
  427. float *input = calloc(inputs, sizeof(float));
  428. int i;
  429. char *line;
  430. while((line=fgetl(stdin)) != 0){
  431. reset_network_state(net, 0);
  432. for(i = 0; i < seed_len; ++i){
  433. c = seed[i];
  434. input[(int)c] = 1;
  435. network_predict(net, input);
  436. input[(int)c] = 0;
  437. }
  438. strip(line);
  439. int str_len = strlen(line);
  440. for(i = 0; i < str_len; ++i){
  441. c = line[i];
  442. input[(int)c] = 1;
  443. network_predict(net, input);
  444. input[(int)c] = 0;
  445. }
  446. c = ' ';
  447. input[(int)c] = 1;
  448. network_predict(net, input);
  449. input[(int)c] = 0;
  450. layer l = net->layers[0];
  451. #ifdef GPU
  452. cuda_pull_array(l.output_gpu, l.output, l.outputs);
  453. #endif
  454. printf("%s", line);
  455. for(i = 0; i < l.outputs; ++i){
  456. printf(",%g", l.output[i]);
  457. }
  458. printf("\n");
  459. }
  460. }
  461. void run_char_rnn(int argc, char **argv)
  462. {
  463. if(argc < 4){
  464. fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
  465. return;
  466. }
  467. char *filename = find_char_arg(argc, argv, "-file", "data/shakespeare.txt");
  468. char *seed = find_char_arg(argc, argv, "-seed", "\n\n");
  469. int len = find_int_arg(argc, argv, "-len", 1000);
  470. float temp = find_float_arg(argc, argv, "-temp", .7);
  471. int rseed = find_int_arg(argc, argv, "-srand", time(0));
  472. int clear = find_arg(argc, argv, "-clear");
  473. int tokenized = find_arg(argc, argv, "-tokenized");
  474. char *tokens = find_char_arg(argc, argv, "-tokens", 0);
  475. char *cfg = argv[3];
  476. char *weights = (argc > 4) ? argv[4] : 0;
  477. if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename, clear, tokenized);
  478. else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights, seed);
  479. else if(0==strcmp(argv[2], "validtactic")) valid_tactic_rnn(cfg, weights, seed);
  480. else if(0==strcmp(argv[2], "vec")) vec_char_rnn(cfg, weights, seed);
  481. else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed, tokens);
  482. else if(0==strcmp(argv[2], "generatetactic")) test_tactic_rnn(cfg, weights, len, temp, rseed, tokens);
  483. }