/* go.c */
  1. #include "darknet.h"
  2. #include <assert.h>
  3. #include <math.h>
  4. #include <unistd.h>
  /* Display switches consulted by print_board(). */
  int inverted = 1;   /* non-zero: rows are labelled 19..1 from the top, otherwise 1..19 */
  int noi = 1;        /* non-zero: column letters skip 'I' (columns after 'H' shift by one) */

  /* Number of ranked candidate indexes highlighted when a board is printed. */
  static const int nind = 10;

  /* Forward declarations for helpers defined further down in this file. */
  int legal_go(float *b, float *ko, int p, int r, int c);
  int check_ko(float *x, float *ko);
  /* A list of fixed-size move records loaded from a data file. */
  typedef struct {
      char **data;  /* data[i] points at one heap-allocated record */
      int n;        /* number of records stored in data */
  } moves;
  /*
   * Read one fixed-size (96-byte) record from fp.
   *
   * Returns a malloc'd 96-byte buffer that the caller must free (the buffer is
   * raw bytes, not a NUL-terminated string), or 0 when fp is NULL, the stream
   * is already at end-of-file, memory cannot be allocated, or fewer than 96
   * bytes could be read.
   */
  char *fgetgo(FILE *fp)
  {
      const size_t size = 96;
      char *line;

      if (!fp || feof(fp)) return 0;

      line = malloc(size*sizeof(char));
      if (!line) return 0;

      /* A short read means a truncated trailing record: discard it. */
      if (fread(line, sizeof(char), size, fp) != size) {
          free(line);
          return 0;
      }
      return line;
  }
  25. moves load_go_moves(char *filename)
  26. {
  27. moves m;
  28. m.n = 128;
  29. m.data = calloc(128, sizeof(char*));
  30. FILE *fp = fopen(filename, "rb");
  31. int count = 0;
  32. char *line = 0;
  33. while ((line = fgetgo(fp))) {
  34. if (count >= m.n) {
  35. m.n *= 2;
  36. m.data = realloc(m.data, m.n*sizeof(char*));
  37. }
  38. m.data[count] = line;
  39. ++count;
  40. }
  41. printf("%d\n", count);
  42. m.n = count;
  43. m.data = realloc(m.data, count*sizeof(char*));
  44. return m;
  45. }
  /*
   * Unpack a 91-byte packed position into two 19x19 float planes.
   *
   * Each cell uses two bits of s (four cells per byte, lowest bits first):
   * the even bit marks a stone in the first plane, the odd bit a stone in the
   * second plane. If both bits are set only the first plane is marked.
   * board must have room for 2*19*19 floats; it is zeroed first.
   */
  void string_to_board(char *s, float *board)
  {
      const int cells = 19*19;
      int cell;

      memset(board, 0, 2*19*19*sizeof(float));
      for (cell = 0; cell < cells; ++cell) {
          char packed = s[cell / 4];
          int shift = 2*(cell % 4);
          if ((packed >> shift) & 1) {
              board[cell] = 1;
          } else if ((packed >> (shift + 1)) & 1) {
              board[cell + cells] = 1;
          }
      }
  }
  /*
   * Pack two 19x19 float planes into 91 bytes, two bits per cell.
   *
   * For cell k, bit 2*(k%4) of s[k/4] is set when the first plane holds
   * exactly 1, and bit 2*(k%4)+1 when the second plane holds exactly 1.
   * s must have room for 91 bytes; it is zeroed first.
   */
  void board_to_string(char *s, float *board)
  {
      const int cells = 19*19;
      int cell;

      memset(s, 0, (19*19/4+1)*sizeof(char));
      for (cell = 0; cell < cells; ++cell) {
          int byte = cell / 4;
          int shift = 2*(cell % 4);
          if (board[cell] == 1) s[byte] = s[byte] | (1 << shift);
          if (board[cell + cells] == 1) s[byte] = s[byte] | (1 << (shift + 1));
      }
  }
  /*
   * Report who holds cell i of board b.
   *
   * Returns 1 when the first plane is non-zero at i, -1 when only the second
   * plane (offset 19*19) is non-zero, and 0 when the cell is empty.
   */
  static int occupied(float *b, int i)
  {
      if (b[i] != 0) return 1;
      return (b[i + 19*19] != 0) ? -1 : 0;
  }
  85. data random_go_moves(moves m, int n)
  86. {
  87. data d = {0};
  88. d.X = make_matrix(n, 19*19*3);
  89. d.y = make_matrix(n, 19*19+2);
  90. int i, j;
  91. for(i = 0; i < n; ++i){
  92. float *board = d.X.vals[i];
  93. float *label = d.y.vals[i];
  94. char *b = m.data[rand()%m.n];
  95. int player = b[0] - '0';
  96. int result = b[1] - '0';
  97. int row = b[2];
  98. int col = b[3];
  99. string_to_board(b+4, board);
  100. if(player > 0) for(j = 0; j < 19*19; ++j) board[19*19*2 + j] = 1;
  101. label[19*19+1] = (player==result);
  102. if(row >= 19 || col >= 19){
  103. label[19*19] = 1;
  104. } else {
  105. label[col + 19*row] = 1;
  106. if(occupied(board, col + 19*row)) printf("hey\n");
  107. }
  108. int flip = rand()%2;
  109. int rotate = rand()%4;
  110. image in = float_to_image(19, 19, 3, board);
  111. image out = float_to_image(19, 19, 1, label);
  112. if(flip){
  113. flip_image(in);
  114. flip_image(out);
  115. }
  116. rotate_image_cw(in, rotate);
  117. rotate_image_cw(out, rotate);
  118. }
  119. return d;
  120. }
  121. void train_go(char *cfgfile, char *weightfile, char *filename, int *gpus, int ngpus, int clear)
  122. {
  123. int i;
  124. float avg_loss = -1;
  125. char *base = basecfg(cfgfile);
  126. printf("%s\n", base);
  127. printf("%d\n", ngpus);
  128. network **nets = calloc(ngpus, sizeof(network*));
  129. srand(time(0));
  130. int seed = rand();
  131. for(i = 0; i < ngpus; ++i){
  132. srand(seed);
  133. #ifdef GPU
  134. cuda_set_device(gpus[i]);
  135. #endif
  136. nets[i] = load_network(cfgfile, weightfile, clear);
  137. nets[i]->learning_rate *= ngpus;
  138. }
  139. network *net = nets[0];
  140. printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
  141. char *backup_directory = "/home/pjreddie/backup/";
  142. char buff[256];
  143. moves m = load_go_moves(filename);
  144. //moves m = load_go_moves("games.txt");
  145. int N = m.n;
  146. printf("Moves: %d\n", N);
  147. int epoch = (*net->seen)/N;
  148. while(get_current_batch(net) < net->max_batches || net->max_batches == 0){
  149. double time=what_time_is_it_now();
  150. data train = random_go_moves(m, net->batch*net->subdivisions*ngpus);
  151. printf("Loaded: %lf seconds\n", what_time_is_it_now() - time);
  152. time=what_time_is_it_now();
  153. float loss = 0;
  154. #ifdef GPU
  155. if(ngpus == 1){
  156. loss = train_network(net, train);
  157. } else {
  158. loss = train_networks(nets, ngpus, train, 10);
  159. }
  160. #else
  161. loss = train_network(net, train);
  162. #endif
  163. free_data(train);
  164. if(avg_loss == -1) avg_loss = loss;
  165. avg_loss = avg_loss*.95 + loss*.05;
  166. printf("%ld, %.3f: %f, %f avg, %f rate, %lf seconds, %ld images\n", get_current_batch(net), (float)(*net->seen)/N, loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, *net->seen);
  167. if(*net->seen/N > epoch){
  168. epoch = *net->seen/N;
  169. char buff[256];
  170. sprintf(buff, "%s/%s_%d.weights", backup_directory,base, epoch);
  171. save_weights(net, buff);
  172. }
  173. if(get_current_batch(net)%1000 == 0){
  174. char buff[256];
  175. sprintf(buff, "%s/%s.backup",backup_directory,base);
  176. save_weights(net, buff);
  177. }
  178. if(get_current_batch(net)%10000 == 0){
  179. char buff[256];
  180. sprintf(buff, "%s/%s_%ld.backup",backup_directory,base,get_current_batch(net));
  181. save_weights(net, buff);
  182. }
  183. }
  184. sprintf(buff, "%s/%s.weights", backup_directory, base);
  185. save_weights(net, buff);
  186. free_network(net);
  187. free(base);
  188. }
  /*
   * Flood-fill the group of `side` stones containing (row, col), adding one to
   * lib[] for every stone reached.
   *
   * visited[] marks stones already counted during this fill so each stone is
   * incremented at most once. Off-board coordinates and cells not held by
   * `side` stop the recursion.
   */
  static void propagate_liberty(float *board, int *lib, int *visited, int row, int col, int side)
  {
      static const int drow[4] = {1, -1, 0, 0};
      static const int dcol[4] = {0, 0, 1, -1};
      int index, k;

      if (row < 0 || row > 18 || col < 0 || col > 18) return;
      index = row*19 + col;
      if (occupied(board, index) != side || visited[index]) return;

      visited[index] = 1;
      lib[index] += 1;
      for (k = 0; k < 4; ++k) {
          propagate_liberty(board, lib, visited, row + drow[k], col + dcol[k], side);
      }
  }
  /*
   * Compute a liberty count for every stone on the board.
   *
   * For each empty point, every adjacent group is flood-filled once and each
   * of its stones gains one liberty; visited[] is reset per empty point so a
   * group touching the same point from two sides is counted only once.
   *
   * Returns a calloc'd array of 19*19 ints (0 for empty points) that the
   * caller must free.
   */
  static int *calculate_liberties(float *board)
  {
      /* Neighbour order: left, right, up, down. */
      static const int drow[4] = {0, 0, -1, 1};
      static const int dcol[4] = {-1, 1, 0, 0};
      int *lib = calloc(19*19, sizeof(int));
      int visited[19*19];
      int row, col, k;

      for (row = 0; row < 19; ++row) {
          for (col = 0; col < 19; ++col) {
              if (occupied(board, row*19 + col)) continue;
              memset(visited, 0, 19*19*sizeof(int));
              for (k = 0; k < 4; ++k) {
                  int nr = row + drow[k];
                  int nc = col + dcol[k];
                  int side;
                  if (nr < 0 || nr > 18 || nc < 0 || nc > 18) continue;
                  side = occupied(board, nr*19 + nc);
                  /* An empty neighbour must not start a fill (side would be 0). */
                  if (side) propagate_liberty(board, lib, visited, nr, nc, side);
              }
          }
      }
      return lib;
  }
  221. void print_board(FILE *stream, float *board, int player, int *indexes)
  222. {
  223. int i,j,n;
  224. fprintf(stream, " ");
  225. for(i = 0; i < 19; ++i){
  226. fprintf(stream, "%c ", 'A' + i + 1*(i > 7 && noi));
  227. }
  228. fprintf(stream, "\n");
  229. for(j = 0; j < 19; ++j){
  230. fprintf(stream, "%2d", (inverted) ? 19-j : j+1);
  231. for(i = 0; i < 19; ++i){
  232. int index = j*19 + i;
  233. if(indexes){
  234. int found = 0;
  235. for(n = 0; n < nind; ++n){
  236. if(index == indexes[n]){
  237. found = 1;
  238. /*
  239. if(n == 0) fprintf(stream, "\uff11");
  240. else if(n == 1) fprintf(stream, "\uff12");
  241. else if(n == 2) fprintf(stream, "\uff13");
  242. else if(n == 3) fprintf(stream, "\uff14");
  243. else if(n == 4) fprintf(stream, "\uff15");
  244. */
  245. fprintf(stream, " %d", n+1);
  246. }
  247. }
  248. if(found) continue;
  249. }
  250. //if(board[index]*-swap > 0) fprintf(stream, "\u25C9 ");
  251. //else if(board[index]*-swap < 0) fprintf(stream, "\u25EF ");
  252. if (occupied(board, index) == player) fprintf(stream, " X");
  253. else if (occupied(board, index) ==-player) fprintf(stream, " O");
  254. else fprintf(stream, " .");
  255. }
  256. fprintf(stream, "\n");
  257. }
  258. }
  /*
   * Switch the point of view of a 3-plane board in place: the two stone
   * planes are exchanged and every entry v of the third plane becomes 1 - v.
   */
  void flip_board(float *board)
  {
      float *first = board;
      float *second = board + 19*19;
      float *turn = board + 19*19*2;
      int i;

      for (i = 0; i < 19*19; ++i) {
          float tmp = first[i];
          first[i] = second[i];
          second[i] = tmp;
          turn[i] = 1 - turn[i];
      }
  }
  269. float predict_move2(network *net, float *board, float *move, int multi)
  270. {
  271. float *output = network_predict(net, board);
  272. copy_cpu(19*19+1, output, 1, move, 1);
  273. float result = output[19*19 + 1];
  274. int i;
  275. if(multi){
  276. image bim = float_to_image(19, 19, 3, board);
  277. for(i = 1; i < 8; ++i){
  278. rotate_image_cw(bim, i);
  279. if(i >= 4) flip_image(bim);
  280. float *output = network_predict(net, board);
  281. image oim = float_to_image(19, 19, 1, output);
  282. result += output[19*19 + 1];
  283. if(i >= 4) flip_image(oim);
  284. rotate_image_cw(oim, -i);
  285. axpy_cpu(19*19+1, 1, output, 1, move, 1);
  286. if(i >= 4) flip_image(bim);
  287. rotate_image_cw(bim, -i);
  288. }
  289. result = result/8;
  290. scal_cpu(19*19+1, 1./8., move, 1);
  291. }
  292. for(i = 0; i < 19*19; ++i){
  293. if(board[i] || board[i+19*19]) move[i] = 0;
  294. }
  295. return result;
  296. }
  /*
   * Remove the group of player p's stones containing (r, c), restricted to
   * stones whose entry in lib[] is exactly 1.
   *
   * Both stone planes are cleared for every removed cell; the recursion then
   * continues into the four neighbours.
   */
  static void remove_connected(float *b, int *lib, int p, int r, int c)
  {
      int index;

      if (r < 0 || r >= 19 || c < 0 || c >= 19) return;
      index = r*19 + c;
      if (occupied(b, index) != p) return;
      if (lib[index] != 1) return;

      b[index] = 0;
      b[19*19 + index] = 0;

      remove_connected(b, lib, p, r+1, c);
      remove_connected(b, lib, p, r-1, c);
      remove_connected(b, lib, p, r, c+1);
      remove_connected(b, lib, p, r, c-1);
  }
  /*
   * Play a stone for player p at (r, c) on board b and remove captured
   * opposing stones.
   *
   * Liberties are computed before the stone is placed, so an adjacent
   * opposing group recorded with exactly one liberty is taken off the board.
   * p > 0 writes to the first stone plane, otherwise to the second.
   */
  void move_go(float *b, int p, int r, int c)
  {
      int *lib = calculate_liberties(b);
      float *plane = (p > 0) ? b : b + 19*19;

      plane[r*19 + c] = 1;

      remove_connected(b, lib, -p, r+1, c);
      remove_connected(b, lib, -p, r-1, c);
      remove_connected(b, lib, -p, r, c+1);
      remove_connected(b, lib, -p, r, c-1);
      free(lib);
  }
  /*
   * Return 1 when two 3-plane boards (19*19*3 floats) are byte-for-byte
   * identical, 0 otherwise.
   */
  int compare_board(float *a, float *b)
  {
      return memcmp(a, b, 19*19*3*sizeof(float)) == 0;
  }
  /*
   * One node of the search tree. All per-move arrays have 19*19+1 entries:
   * one per board point plus a final entry for passing.
   */
  typedef struct mcts_tree{
      float *board;                 /* 3-plane position at this node; freed with the node */
      struct mcts_tree **children;  /* child node per move, NULL until expanded */
      float *prior;                 /* network move scores copied at expansion */
      int *visit_count;             /* number of times each move was selected */
      float *value;                 /* accumulated value per move */
      float *mean;                  /* value / visit_count per move */
      float *prob;                  /* selection score computed on each descent */
      int total_count;              /* total selections through this node (starts at 1) */
      float result;                 /* network value estimate for this node, 2*v - 1 */
      int done;                     /* set on a pass node reached right after another pass */
      int pass;                     /* set when this node was reached by a pass */
  } mcts_tree;
  338. void free_mcts(mcts_tree *root)
  339. {
  340. if(!root) return;
  341. int i;
  342. free(root->board);
  343. for(i = 0; i < 19*19+1; ++i){
  344. if(root->children[i]) free_mcts(root->children[i]);
  345. }
  346. free(root->children);
  347. free(root->prior);
  348. free(root->visit_count);
  349. free(root->value);
  350. free(root->mean);
  351. free(root->prob);
  352. free(root);
  353. }
  354. float *network_predict_rotations(network *net, float *next)
  355. {
  356. int n = net->batch;
  357. float *in = calloc(19*19*3*n, sizeof(float));
  358. image im = float_to_image(19, 19, 3, next);
  359. int i,j;
  360. int *inds = random_index_order(0, 8);
  361. for(j = 0; j < n; ++j){
  362. i = inds[j];
  363. rotate_image_cw(im, i);
  364. if(i >= 4) flip_image(im);
  365. memcpy(in + 19*19*3*j, im.data, 19*19*3*sizeof(float));
  366. if(i >= 4) flip_image(im);
  367. rotate_image_cw(im, -i);
  368. }
  369. float *pred = network_predict(net, in);
  370. for(j = 0; j < n; ++j){
  371. i = inds[j];
  372. image im = float_to_image(19, 19, 1, pred + j*(19*19 + 2));
  373. if(i >= 4) flip_image(im);
  374. rotate_image_cw(im, -i);
  375. if(j > 0){
  376. axpy_cpu(19*19+2, 1, im.data, 1, pred, 1);
  377. }
  378. }
  379. free(in);
  380. free(inds);
  381. scal_cpu(19*19+2, 1./n, pred, 1);
  382. return pred;
  383. }
  384. mcts_tree *expand(float *next, float *ko, network *net)
  385. {
  386. mcts_tree *root = calloc(1, sizeof(mcts_tree));
  387. root->board = next;
  388. root->children = calloc(19*19+1, sizeof(mcts_tree*));
  389. root->prior = calloc(19*19 + 1, sizeof(float));
  390. root->prob = calloc(19*19 + 1, sizeof(float));
  391. root->mean = calloc(19*19 + 1, sizeof(float));
  392. root->value = calloc(19*19 + 1, sizeof(float));
  393. root->visit_count = calloc(19*19 + 1, sizeof(int));
  394. root->total_count = 1;
  395. int i;
  396. float *pred = network_predict_rotations(net, next);
  397. copy_cpu(19*19+1, pred, 1, root->prior, 1);
  398. float val = 2*pred[19*19 + 1] - 1;
  399. root->result = val;
  400. for(i = 0; i < 19*19+1; ++i) {
  401. root->visit_count[i] = 0;
  402. root->value[i] = 0;
  403. root->mean[i] = val;
  404. if(i < 19*19 && occupied(next, i)){
  405. root->value[i] = -1;
  406. root->mean[i] = -1;
  407. root->prior[i] = 0;
  408. }
  409. }
  410. //print_board(stderr, next, flip?-1:1, 0);
  411. return root;
  412. }
  /*
   * Return a newly allocated copy of a 3-plane board (19*19*3 floats).
   * The caller owns the result.
   */
  float *copy_board(float *board)
  {
      const int len = 19*19*3;
      float *dup = calloc(len, sizeof(float));

      copy_cpu(len, board, 1, dup, 1);
      return dup;
  }
  419. float select_mcts(mcts_tree *root, network *net, float *prev, float cpuct)
  420. {
  421. if(root->done) return -root->result;
  422. int i;
  423. float max = -1000;
  424. int max_i = 0;
  425. for(i = 0; i < 19*19+1; ++i){
  426. root->prob[i] = root->mean[i] + cpuct*root->prior[i] * sqrt(root->total_count) / (1. + root->visit_count[i]);
  427. if(root->prob[i] > max){
  428. max = root->prob[i];
  429. max_i = i;
  430. }
  431. }
  432. float val;
  433. i = max_i;
  434. root->visit_count[i]++;
  435. root->total_count++;
  436. if (root->children[i]) {
  437. val = select_mcts(root->children[i], net, root->board, cpuct);
  438. } else {
  439. if(max_i < 19*19 && !legal_go(root->board, prev, 1, max_i/19, max_i%19)) {
  440. root->mean[i] = -1;
  441. root->value[i] = -1;
  442. root->prior[i] = 0;
  443. --root->total_count;
  444. return select_mcts(root, net, prev, cpuct);
  445. //printf("Detected ko\n");
  446. //getchar();
  447. } else {
  448. float *next = copy_board(root->board);
  449. if (max_i < 19*19) {
  450. move_go(next, 1, max_i / 19, max_i % 19);
  451. }
  452. flip_board(next);
  453. root->children[i] = expand(next, root->board, net);
  454. val = -root->children[i]->result;
  455. if(max_i == 19*19){
  456. root->children[i]->pass = 1;
  457. if (root->pass){
  458. root->children[i]->done = 1;
  459. }
  460. }
  461. }
  462. }
  463. root->value[i] += val;
  464. root->mean[i] = root->value[i]/root->visit_count[i];
  465. return -val;
  466. }
  467. mcts_tree *run_mcts(mcts_tree *tree, network *net, float *board, float *ko, int player, int n, float cpuct, float secs)
  468. {
  469. int i;
  470. double t = what_time_is_it_now();
  471. if(player < 0) flip_board(board);
  472. if(!tree) tree = expand(copy_board(board), ko, net);
  473. assert(compare_board(tree->board, board));
  474. for(i = 0; i < n; ++i){
  475. if (secs > 0 && (what_time_is_it_now() - t) > secs) break;
  476. int max_i = max_int_index(tree->visit_count, 19*19+1);
  477. if (tree->visit_count[max_i] >= n) break;
  478. select_mcts(tree, net, ko, cpuct);
  479. }
  480. if(player < 0) flip_board(board);
  481. //fprintf(stderr, "%f Seconds\n", what_time_is_it_now() - t);
  482. return tree;
  483. }
  484. mcts_tree *move_mcts(mcts_tree *tree, int index)
  485. {
  486. if(index < 0 || index > 19*19 || !tree || !tree->children[index]) {
  487. free_mcts(tree);
  488. tree = 0;
  489. } else {
  490. mcts_tree *swap = tree;
  491. tree = tree->children[index];
  492. swap->children[index] = 0;
  493. free_mcts(swap);
  494. }
  495. return tree;
  496. }
  /* A move chosen by the search, with the two value estimates that go with it. */
  typedef struct {
      float value;  /* node result rescaled by (r + 1) / 2 */
      float mcts;   /* mean value of the chosen move rescaled by (m + 1) / 2 */
      int row;      /* move index / 19 (19 for the pass index 19*19) */
      int col;      /* move index % 19 */
  } move;
  503. move pick_move(mcts_tree *tree, float temp, int player)
  504. {
  505. int i;
  506. float probs[19*19+1] = {0};
  507. move m = {0};
  508. double sum = 0;
  509. /*
  510. for(i = 0; i < 19*19+1; ++i){
  511. probs[i] = tree->visit_count[i];
  512. }
  513. */
  514. //softmax(probs, 19*19+1, temp, 1, probs);
  515. for(i = 0; i < 19*19+1; ++i){
  516. sum += pow(tree->visit_count[i], 1./temp);
  517. }
  518. for(i = 0; i < 19*19+1; ++i){
  519. probs[i] = pow(tree->visit_count[i], 1./temp) / sum;
  520. }
  521. int index = sample_array(probs, 19*19+1);
  522. m.row = index / 19;
  523. m.col = index % 19;
  524. m.value = (tree->result+1.)/2.;
  525. m.mcts = (tree->mean[index]+1.)/2.;
  526. int indexes[nind];
  527. top_k(probs, 19*19+1, nind, indexes);
  528. print_board(stderr, tree->board, player, indexes);
  529. fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", index/19, index%19, tree->result, tree->prior[index], probs[index], tree->mean[index], (tree->children[index])?tree->children[index]->result:0, tree->visit_count[index]);
  530. int ind = max_index(probs, 19*19+1);
  531. fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", ind/19, ind%19, tree->result, tree->prior[ind], probs[ind], tree->mean[ind], (tree->children[ind])?tree->children[ind]->result:0, tree->visit_count[ind]);
  532. ind = max_index(tree->prior, 19*19+1);
  533. fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", ind/19, ind%19, tree->result, tree->prior[ind], probs[ind], tree->mean[ind], (tree->children[ind])?tree->children[ind]->result:0, tree->visit_count[ind]);
  534. return m;
  535. }
  536. /*
  537. float predict_move(network *net, float *board, float *move, int multi, float *ko, float temp)
  538. {
  539. int i;
  540. int max_v = 0;
  541. int max_i = 0;
  542. for(i = 0; i < 19*19+1; ++i){
  543. if(root->visit_count[i] > max_v){
  544. max_v = root->visit_count[i];
  545. max_i = i;
  546. }
  547. }
  548. fprintf(stderr, "%f Seconds\n", what_time_is_it_now() - t);
  549. int ind = max_index(root->mean, 19*19+1);
  550. fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", max_i/19, max_i%19, root->result, root->prior[max_i], root->prob[max_i], root->mean[max_i], (root->children[max_i])?root->children[max_i]->result:0, root->visit_count[max_i]);
  551. fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", ind/19, ind%19, root->result, root->prior[ind], root->prob[ind], root->mean[ind], (root->children[ind])?root->children[ind]->result:0, root->visit_count[ind]);
  552. ind = max_index(root->prior, 19*19+1);
  553. fprintf(stderr, "%d %d, Result: %f, Prior: %f, Prob: %f, Mean Value: %f, Child Result: %f, Visited: %d\n", ind/19, ind%19, root->result, root->prior[ind], root->prob[ind], root->mean[ind], (root->children[ind])?root->children[ind]->result:0, root->visit_count[ind]);
  554. if(root->result < -.9 && root->mean[max_i] < -.9) return -1000.f;
  555. float val = root->result;
  556. free_mcts(root);
  557. return val;
  558. }
  559. */
  /*
   * Decide whether neighbour (r, c) keeps a stone played by p next to it alive.
   *
   * Off-board neighbours never help (0). An opposing stone helps only when its
   * liberty count is at most 1. An empty point always helps. A friendly stone
   * helps when its liberty count is above 1.
   */
  static int makes_safe_go(float *b, int *lib, int p, int r, int c)
  {
      int index, owner;

      if (r < 0 || r >= 19 || c < 0 || c >= 19) return 0;
      index = r*19 + c;
      owner = occupied(b, index);

      if (owner == -p) return (lib[index] > 1) ? 0 : 1;
      if (owner == 0) return 1;
      return (lib[index] > 1) ? 1 : 0;
  }
  /*
   * Return 1 when playing p at (r, c) would be suicide, i.e. none of the four
   * neighbours makes the stone safe according to makes_safe_go(); 0 otherwise.
   * The board is not modified.
   */
  int suicide_go(float *b, int p, int r, int c)
  {
      int *lib = calculate_liberties(b);
      int safe = makes_safe_go(b, lib, p, r+1, c)
              || makes_safe_go(b, lib, p, r-1, c)
              || makes_safe_go(b, lib, p, r, c+1)
              || makes_safe_go(b, lib, p, r, c-1);

      free(lib);
      return !safe;
  }
  /*
   * Return 1 when position x repeats position ko, 0 otherwise (and 0 when ko
   * is NULL).
   *
   * x is compared on a private copy; if the first entry of its third plane
   * differs from ko's, the copy is flipped first so both are seen from the
   * same side.
   */
  int check_ko(float *x, float *ko)
  {
      float view[19*19*3];

      if (ko == 0) return 0;

      copy_cpu(19*19*3, x, 1, view, 1);
      if (view[19*19*2] != ko[19*19*2]) flip_board(view);
      return compare_board(view, ko) ? 1 : 0;
  }
  /*
   * Return 1 when player p may play at (r, c) on board b, 0 otherwise.
   *
   * A move is rejected when the point is occupied, when the position after
   * the move matches `ko` (see check_ko), or when the move is suicide. The
   * board b itself is never modified; the move is tried on a copy.
   */
  int legal_go(float *b, float *ko, int p, int r, int c)
  {
      float trial[19*19*3];

      if (occupied(b, r*19+c)) return 0;

      copy_cpu(19*19*3, b, 1, trial, 1);
      move_go(trial, p, r, c);
      return !check_ko(trial, ko) && !suicide_go(b, p, r, c);
  }
  600. /*
  601. move generate_move(mcts_tree *root, network *net, int player, float *board, int multi, float temp, float *ko, int print)
  602. {
  603. move m = {0};
  604. //root = run_mcts(tree, network *net, float *board, float *ko, int n, float cpuct)
  605. int i, j;
  606. int empty = 1;
  607. for(i = 0; i < 19*19; ++i){
  608. if (occupied(board, i)) {
  609. empty = 0;
  610. break;
  611. }
  612. }
  613. if(empty) {
  614. m.value = .5;
  615. m.mcts = .5;
  616. m.row = 3;
  617. m.col = 15;
  618. return m;
  619. }
  620. float move[362];
  621. if (player < 0) flip_board(board);
  622. float result = predict_move(net, board, move, multi, ko, temp);
  623. if (player < 0) flip_board(board);
  624. if(result == -1000.f) return -2;
  625. for(i = 0; i < 19; ++i){
  626. for(j = 0; j < 19; ++j){
  627. if (!legal_go(board, ko, player, i, j)) move[i*19 + j] = 0;
  628. }
  629. }
  630. int indexes[nind];
  631. top_k(move, 19*19+1, nind, indexes);
  632. int max = max_index(move, 19*19+1);
  633. int row = max / 19;
  634. int col = max % 19;
  635. int index = sample_array(move, 19*19+1);
  636. if(print){
  637. top_k(move, 19*19+1, nind, indexes);
  638. for(i = 0; i < nind; ++i){
  639. if (!move[indexes[i]]) indexes[i] = -1;
  640. }
  641. print_board(stderr, board, 1, indexes);
  642. fprintf(stderr, "%s To Move\n", player > 0 ? "X" : "O");
  643. fprintf(stderr, "%.2f%% Win Chance\n", (result+1)/2*100);
  644. for(i = 0; i < nind; ++i){
  645. int index = indexes[i];
  646. int row = index / 19;
  647. int col = index % 19;
  648. if(row == 19){
  649. fprintf(stderr, "%d: Pass, %.2f%%\n", i+1, move[index]*100);
  650. } else {
  651. fprintf(stderr, "%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
  652. }
  653. }
  654. }
  655. if (row == 19) return -1;
  656. if (suicide_go(board, player, row, col)){
  657. return -1;
  658. }
  659. if (suicide_go(board, player, index/19, index%19)){
  660. index = max;
  661. }
  662. if (index == 19*19) return -1;
  663. return index;
  664. }
  665. */
  666. void valid_go(char *cfgfile, char *weightfile, int multi, char *filename)
  667. {
  668. srand(time(0));
  669. char *base = basecfg(cfgfile);
  670. printf("%s\n", base);
  671. network *net = load_network(cfgfile, weightfile, 0);
  672. set_batch_network(net, 1);
  673. printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
  674. float *board = calloc(19*19*3, sizeof(float));
  675. float *move = calloc(19*19+2, sizeof(float));
  676. // moves m = load_go_moves("/home/pjreddie/backup/go.test");
  677. moves m = load_go_moves(filename);
  678. int N = m.n;
  679. int i,j;
  680. int correct = 0;
  681. for (i = 0; i <N; ++i) {
  682. char *b = m.data[i];
  683. int player = b[0] - '0';
  684. //int result = b[1] - '0';
  685. int row = b[2];
  686. int col = b[3];
  687. int truth = col + 19*row;
  688. string_to_board(b+4, board);
  689. if(player > 0) for(j = 0; j < 19*19; ++j) board[19*19*2 + j] = 1;
  690. predict_move2(net, board, move, multi);
  691. int index = max_index(move, 19*19+1);
  692. if(index == truth) ++correct;
  693. printf("%d Accuracy %f\n", i, (float) correct/(i+1));
  694. }
  695. }
  /*
   * Write the position to fp as a command script: komi, boardsize and
   * clear_board, followed by one "play" line per stone (first plane as black,
   * second as white). Columns are lettered from 'A' skipping 'I' and rows are
   * numbered 19 down to 1.
   *
   * Returns the number of lines written (3 plus the number of stones).
   */
  int print_game(float *board, FILE *fp)
  {
      int row, col;
      int lines = 3;

      fprintf(fp, "komi 6.5\n");
      fprintf(fp, "boardsize 19\n");
      fprintf(fp, "clear_board\n");

      for (row = 0; row < 19; ++row) {
          for (col = 0; col < 19; ++col) {
              int owner = occupied(board, row*19 + col);
              int letter = 'A' + col + (col >= 8);
              int number = 19 - row;

              if (owner == 1) fprintf(fp, "play black %c%d\n", letter, number);
              if (owner == -1) fprintf(fp, "play white %c%d\n", letter, number);
              if (owner) ++lines;
          }
      }
      return lines;
  }
  /*
   * Poll standard input without blocking.
   *
   * Returns 1 when select() reports stdin as readable, 0 otherwise. A select()
   * failure (return value -1) is reported as "not ready" rather than being
   * mistaken for available input.
   */
  int stdin_ready(void)
  {
      fd_set readfds;
      struct timeval timeout;

      FD_ZERO(&readfds);
      FD_SET(STDIN_FILENO, &readfds);
      /* A zero timeout makes select() return immediately. */
      timeout.tv_sec = 0;
      timeout.tv_usec = 0;

      return select(STDIN_FILENO + 1, &readfds, NULL, NULL, &timeout) > 0;
  }
  725. mcts_tree *ponder(mcts_tree *tree, network *net, float *b, float *ko, int player, float cpuct)
  726. {
  727. double t = what_time_is_it_now();
  728. int count = 0;
  729. if (tree) count = tree->total_count;
  730. while(!stdin_ready()){
  731. if (what_time_is_it_now() - t > 120) break;
  732. tree = run_mcts(tree, net, b, ko, player, 100000, cpuct, .1);
  733. }
  734. fprintf(stderr, "Pondered %d moves...\n", tree->total_count - count);
  735. return tree;
  736. }
  737. void engine_go(char *filename, char *weightfile, int mcts_iters, float secs, float temp, float cpuct, int anon, int resign)
  738. {
  739. mcts_tree *root = 0;
  740. network *net = load_network(filename, weightfile, 0);
  741. set_batch_network(net, 1);
  742. srand(time(0));
  743. float *board = calloc(19*19*3, sizeof(float));
  744. flip_board(board);
  745. float *one = calloc(19*19*3, sizeof(float));
  746. float *two = calloc(19*19*3, sizeof(float));
  747. int ponder_player = 0;
  748. int passed = 0;
  749. int move_num = 0;
  750. int main_time = 0;
  751. int byo_yomi_time = 0;
  752. int byo_yomi_stones = 0;
  753. int black_time_left = 0;
  754. int black_stones_left = 0;
  755. int white_time_left = 0;
  756. int white_stones_left = 0;
  757. float orig_time = secs;
  758. int old_ponder = 0;
  759. while(1){
  760. if(ponder_player){
  761. root = ponder(root, net, board, two, ponder_player, cpuct);
  762. }
  763. old_ponder = ponder_player;
  764. ponder_player = 0;
  765. char buff[256];
  766. int id = 0;
  767. int has_id = (scanf("%d", &id) == 1);
  768. scanf("%s", buff);
  769. if (feof(stdin)) break;
  770. fprintf(stderr, "%s\n", buff);
  771. char ids[256];
  772. sprintf(ids, "%d", id);
  773. //fprintf(stderr, "%s\n", buff);
  774. if (!has_id) ids[0] = 0;
  775. if (!strcmp(buff, "protocol_version")){
  776. printf("=%s 2\n\n", ids);
  777. } else if (!strcmp(buff, "name")){
  778. if(anon){
  779. printf("=%s The Fool!\n\n", ids);
  780. }else{
  781. printf("=%s DarkGo\n\n", ids);
  782. }
  783. } else if (!strcmp(buff, "time_settings")){
  784. ponder_player = old_ponder;
  785. scanf("%d %d %d", &main_time, &byo_yomi_time, &byo_yomi_stones);
  786. printf("=%s \n\n", ids);
  787. } else if (!strcmp(buff, "time_left")){
  788. ponder_player = old_ponder;
  789. char color[256];
  790. int time = 0, stones = 0;
  791. scanf("%s %d %d", color, &time, &stones);
  792. if (color[0] == 'b' || color[0] == 'B'){
  793. black_time_left = time;
  794. black_stones_left = stones;
  795. } else {
  796. white_time_left = time;
  797. white_stones_left = stones;
  798. }
  799. printf("=%s \n\n", ids);
  800. } else if (!strcmp(buff, "version")){
  801. if(anon){
  802. printf("=%s :-DDDD\n\n", ids);
  803. }else {
  804. printf("=%s 1.0. Want more DarkGo? You can find me on OGS, unlimited games, no waiting! https://online-go.com/user/view/434218\n\n", ids);
  805. }
  806. } else if (!strcmp(buff, "known_command")){
  807. char comm[256];
  808. scanf("%s", comm);
  809. int known = (!strcmp(comm, "protocol_version") ||
  810. !strcmp(comm, "name") ||
  811. !strcmp(comm, "version") ||
  812. !strcmp(comm, "known_command") ||
  813. !strcmp(comm, "list_commands") ||
  814. !strcmp(comm, "quit") ||
  815. !strcmp(comm, "boardsize") ||
  816. !strcmp(comm, "clear_board") ||
  817. !strcmp(comm, "komi") ||
  818. !strcmp(comm, "final_status_list") ||
  819. !strcmp(comm, "play") ||
  820. !strcmp(comm, "genmove_white") ||
  821. !strcmp(comm, "genmove_black") ||
  822. !strcmp(comm, "fixed_handicap") ||
  823. !strcmp(comm, "genmove"));
  824. if(known) printf("=%s true\n\n", ids);
  825. else printf("=%s false\n\n", ids);
  826. } else if (!strcmp(buff, "list_commands")){
  827. printf("=%s protocol_version\nshowboard\nname\nversion\nknown_command\nlist_commands\nquit\nboardsize\nclear_board\nkomi\nplay\ngenmove_black\ngenmove_white\ngenmove\nfinal_status_list\nfixed_handicap\n\n", ids);
  828. } else if (!strcmp(buff, "quit")){
  829. break;
  830. } else if (!strcmp(buff, "boardsize")){
  831. int boardsize = 0;
  832. scanf("%d", &boardsize);
  833. //fprintf(stderr, "%d\n", boardsize);
  834. if(boardsize != 19){
  835. printf("?%s unacceptable size\n\n", ids);
  836. } else {
  837. root = move_mcts(root, -1);
  838. memset(board, 0, 3*19*19*sizeof(float));
  839. flip_board(board);
  840. move_num = 0;
  841. printf("=%s \n\n", ids);
  842. }
  843. } else if (!strcmp(buff, "fixed_handicap")){
  844. int handicap = 0;
  845. scanf("%d", &handicap);
  846. int indexes[] = {72, 288, 300, 60, 180, 174, 186, 66, 294};
  847. int i;
  848. for(i = 0; i < handicap; ++i){
  849. board[indexes[i]] = 1;
  850. ++move_num;
  851. }
  852. root = move_mcts(root, -1);
  853. } else if (!strcmp(buff, "clear_board")){
  854. passed = 0;
  855. memset(board, 0, 3*19*19*sizeof(float));
  856. flip_board(board);
  857. move_num = 0;
  858. root = move_mcts(root, -1);
  859. printf("=%s \n\n", ids);
  860. } else if (!strcmp(buff, "komi")){
  861. float komi = 0;
  862. scanf("%f", &komi);
  863. printf("=%s \n\n", ids);
  864. } else if (!strcmp(buff, "showboard")){
  865. printf("=%s \n", ids);
  866. print_board(stdout, board, 1, 0);
  867. printf("\n");
  868. } else if (!strcmp(buff, "play") || !strcmp(buff, "black") || !strcmp(buff, "white")){
  869. ++move_num;
  870. char color[256];
  871. if(!strcmp(buff, "play"))
  872. {
  873. scanf("%s ", color);
  874. } else {
  875. scanf(" ");
  876. color[0] = buff[0];
  877. }
  878. char c;
  879. int r;
  880. int count = scanf("%c%d", &c, &r);
  881. int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
  882. if((c == 'p' || c == 'P') && count < 2) {
  883. passed = 1;
  884. printf("=%s \n\n", ids);
  885. char *line = fgetl(stdin);
  886. free(line);
  887. fflush(stdout);
  888. fflush(stderr);
  889. root = move_mcts(root, 19*19);
  890. continue;
  891. } else {
  892. passed = 0;
  893. }
  894. if(c >= 'A' && c <= 'Z') c = c - 'A';
  895. if(c >= 'a' && c <= 'z') c = c - 'a';
  896. if(c >= 8) --c;
  897. r = 19 - r;
  898. fprintf(stderr, "move: %d %d\n", r, c);
  899. float *swap = two;
  900. two = one;
  901. one = swap;
  902. move_go(board, player, r, c);
  903. copy_cpu(19*19*3, board, 1, one, 1);
  904. if(root) fprintf(stderr, "Prior: %f\n", root->prior[r*19 + c]);
  905. if(root) fprintf(stderr, "Mean: %f\n", root->mean[r*19 + c]);
  906. if(root) fprintf(stderr, "Result: %f\n", root->result);
  907. root = move_mcts(root, r*19 + c);
  908. if(root) fprintf(stderr, "Visited: %d\n", root->total_count);
  909. else fprintf(stderr, "NOT VISITED\n");
  910. printf("=%s \n\n", ids);
  911. //print_board(stderr, board, 1, 0);
  912. } else if (!strcmp(buff, "genmove") || !strcmp(buff, "genmove_black") || !strcmp(buff, "genmove_white")){
  913. ++move_num;
  914. int player = 0;
  915. if(!strcmp(buff, "genmove")){
  916. char color[256];
  917. scanf("%s", color);
  918. player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
  919. } else if (!strcmp(buff, "genmove_black")){
  920. player = 1;
  921. } else {
  922. player = -1;
  923. }
  924. if(player > 0){
  925. if(black_time_left <= 30) secs = 2.5;
  926. else secs = orig_time;
  927. } else {
  928. if(white_time_left <= 30) secs = 2.5;
  929. else secs = orig_time;
  930. }
  931. ponder_player = -player;
  932. //tree = generate_move(net, player, board, multi, .1, two, 1);
  933. double t = what_time_is_it_now();
  934. root = run_mcts(root, net, board, two, player, mcts_iters, cpuct, secs);
  935. fprintf(stderr, "%f Seconds\n", what_time_is_it_now() - t);
  936. move m = pick_move(root, temp, player);
  937. root = move_mcts(root, m.row*19 + m.col);
  938. if(move_num > resign && m.value < .1 && m.mcts < .1){
  939. printf("=%s resign\n\n", ids);
  940. } else if(m.row == 19){
  941. printf("=%s pass\n\n", ids);
  942. passed = 0;
  943. } else {
  944. int row = m.row;
  945. int col = m.col;
  946. float *swap = two;
  947. two = one;
  948. one = swap;
  949. move_go(board, player, row, col);
  950. copy_cpu(19*19*3, board, 1, one, 1);
  951. row = 19 - row;
  952. if (col >= 8) ++col;
  953. printf("=%s %c%d\n\n", ids, 'A' + col, row);
  954. }
  955. } else if (!strcmp(buff, "p")){
  956. //print_board(board, 1, 0);
  957. } else if (!strcmp(buff, "final_status_list")){
  958. char type[256];
  959. scanf("%s", type);
  960. fprintf(stderr, "final_status\n");
  961. char *line = fgetl(stdin);
  962. free(line);
  963. if(type[0] == 'd' || type[0] == 'D'){
  964. int i;
  965. FILE *f = fopen("game.txt", "w");
  966. int count = print_game(board, f);
  967. fprintf(f, "%s final_status_list dead\n", ids);
  968. fclose(f);
  969. FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
  970. for(i = 0; i < count; ++i){
  971. free(fgetl(p));
  972. free(fgetl(p));
  973. }
  974. char *l = 0;
  975. while((l = fgetl(p))){
  976. printf("%s\n", l);
  977. free(l);
  978. }
  979. } else {
  980. printf("?%s unknown command\n\n", ids);
  981. }
  982. } else if (!strcmp(buff, "kgs-genmove_cleanup")){
  983. char type[256];
  984. scanf("%s", type);
  985. fprintf(stderr, "kgs-genmove_cleanup\n");
  986. char *line = fgetl(stdin);
  987. free(line);
  988. int i;
  989. FILE *f = fopen("game.txt", "w");
  990. int count = print_game(board, f);
  991. fprintf(f, "%s kgs-genmove_cleanup %s\n", ids, type);
  992. fclose(f);
  993. FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
  994. for(i = 0; i < count; ++i){
  995. free(fgetl(p));
  996. free(fgetl(p));
  997. }
  998. char *l = 0;
  999. while((l = fgetl(p))){
  1000. printf("%s\n", l);
  1001. free(l);
  1002. }
  1003. } else {
  1004. char *line = fgetl(stdin);
  1005. free(line);
  1006. printf("?%s unknown command\n\n", ids);
  1007. }
  1008. fflush(stdout);
  1009. fflush(stderr);
  1010. }
  1011. printf("%d %d %d\n",passed, black_stones_left, white_stones_left);
  1012. }
  1013. void test_go(char *cfg, char *weights, int multi)
  1014. {
  1015. int i;
  1016. network *net = load_network(cfg, weights, 0);
  1017. set_batch_network(net, 1);
  1018. srand(time(0));
  1019. float *board = calloc(19*19*3, sizeof(float));
  1020. flip_board(board);
  1021. float *move = calloc(19*19+1, sizeof(float));
  1022. int color = 1;
  1023. while(1){
  1024. float result = predict_move2(net, board, move, multi);
  1025. printf("%.2f%% Win Chance\n", (result+1)/2*100);
  1026. int indexes[nind];
  1027. int row, col;
  1028. top_k(move, 19*19+1, nind, indexes);
  1029. print_board(stderr, board, color, indexes);
  1030. for(i = 0; i < nind; ++i){
  1031. int index = indexes[i];
  1032. row = index / 19;
  1033. col = index % 19;
  1034. if(row == 19){
  1035. printf("%d: Pass, %.2f%%\n", i+1, move[index]*100);
  1036. } else {
  1037. printf("%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
  1038. }
  1039. }
  1040. //if(color == 1) printf("\u25EF Enter move: ");
  1041. //else printf("\u25C9 Enter move: ");
  1042. if(color == 1) printf("X Enter move: ");
  1043. else printf("O Enter move: ");
  1044. char c;
  1045. char *line = fgetl(stdin);
  1046. int picked = 1;
  1047. int dnum = sscanf(line, "%d", &picked);
  1048. int cnum = sscanf(line, "%c", &c);
  1049. if (strlen(line) == 0 || dnum) {
  1050. --picked;
  1051. if (picked < nind){
  1052. int index = indexes[picked];
  1053. row = index / 19;
  1054. col = index % 19;
  1055. if(row < 19){
  1056. move_go(board, 1, row, col);
  1057. }
  1058. }
  1059. } else if (cnum){
  1060. if (c <= 'T' && c >= 'A'){
  1061. int num = sscanf(line, "%c %d", &c, &row);
  1062. row = (inverted)?19 - row : row-1;
  1063. col = c - 'A';
  1064. if (col > 7 && noi) col -= 1;
  1065. if (num == 2) move_go(board, 1, row, col);
  1066. } else if (c == 'p') {
  1067. // Pass
  1068. } else if(c=='b' || c == 'w'){
  1069. char g;
  1070. int num = sscanf(line, "%c %c %d", &g, &c, &row);
  1071. row = (inverted)?19 - row : row-1;
  1072. col = c - 'A';
  1073. if (col > 7 && noi) col -= 1;
  1074. if (num == 3) {
  1075. int mc = (g == 'b') ? 1 : -1;
  1076. if (mc == color) {
  1077. board[row*19 + col] = 1;
  1078. } else {
  1079. board[19*19 + row*19 + col] = 1;
  1080. }
  1081. }
  1082. } else if(c == 'c'){
  1083. char g;
  1084. int num = sscanf(line, "%c %c %d", &g, &c, &row);
  1085. row = (inverted)?19 - row : row-1;
  1086. col = c - 'A';
  1087. if (col > 7 && noi) col -= 1;
  1088. if (num == 3) {
  1089. board[row*19 + col] = 0;
  1090. board[19*19 + row*19 + col] = 0;
  1091. }
  1092. }
  1093. }
  1094. free(line);
  1095. flip_board(board);
  1096. color = -color;
  1097. }
  1098. }
  1099. float score_game(float *board)
  1100. {
  1101. int i;
  1102. FILE *f = fopen("game.txt", "w");
  1103. int count = print_game(board, f);
  1104. fprintf(f, "final_score\n");
  1105. fclose(f);
  1106. FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
  1107. for(i = 0; i < count; ++i){
  1108. free(fgetl(p));
  1109. free(fgetl(p));
  1110. }
  1111. char *l = 0;
  1112. float score = 0;
  1113. char player = 0;
  1114. while((l = fgetl(p))){
  1115. fprintf(stderr, "%s \t", l);
  1116. int n = sscanf(l, "= %c+%f", &player, &score);
  1117. free(l);
  1118. if (n == 2) break;
  1119. }
  1120. if(player == 'W') score = -score;
  1121. pclose(p);
  1122. return score;
  1123. }
  1124. void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
  1125. {
  1126. mcts_tree *tree1 = 0;
  1127. mcts_tree *tree2 = 0;
  1128. network *net = load_network(filename, weightfile, 0);
  1129. //set_batch_network(net, 1);
  1130. network *net2;
  1131. if (f2) {
  1132. net2 = parse_network_cfg(f2);
  1133. if(w2){
  1134. load_weights(net2, w2);
  1135. }
  1136. } else {
  1137. net2 = calloc(1, sizeof(network));
  1138. *net2 = *net;
  1139. }
  1140. srand(time(0));
  1141. char boards[600][93];
  1142. int count = 0;
  1143. //set_batch_network(net, 1);
  1144. //set_batch_network(net2, 1);
  1145. float *board = calloc(19*19*3, sizeof(float));
  1146. flip_board(board);
  1147. float *one = calloc(19*19*3, sizeof(float));
  1148. float *two = calloc(19*19*3, sizeof(float));
  1149. int done = 0;
  1150. int player = 1;
  1151. int p1 = 0;
  1152. int p2 = 0;
  1153. int total = 0;
  1154. float temp = .1;
  1155. int mcts_iters = 500;
  1156. float cpuct = 5;
  1157. while(1){
  1158. if (done){
  1159. tree1 = move_mcts(tree1, -1);
  1160. tree2 = move_mcts(tree2, -1);
  1161. float score = score_game(board);
  1162. if((score > 0) == (total%2==0)) ++p1;
  1163. else ++p2;
  1164. ++total;
  1165. fprintf(stderr, "Total: %d, Player 1: %f, Player 2: %f\n", total, (float)p1/total, (float)p2/total);
  1166. sleep(1);
  1167. /*
  1168. int i = (score > 0)? 0 : 1;
  1169. int j;
  1170. for(; i < count; i += 2){
  1171. for(j = 0; j < 93; ++j){
  1172. printf("%c", boards[i][j]);
  1173. }
  1174. printf("\n");
  1175. }
  1176. */
  1177. memset(board, 0, 3*19*19*sizeof(float));
  1178. flip_board(board);
  1179. player = 1;
  1180. done = 0;
  1181. count = 0;
  1182. fflush(stdout);
  1183. fflush(stderr);
  1184. }
  1185. //print_board(stderr, board, 1, 0);
  1186. //sleep(1);
  1187. if ((total%2==0) == (player==1)){
  1188. //mcts_iters = 4500;
  1189. cpuct = 5;
  1190. } else {
  1191. //mcts_iters = 500;
  1192. cpuct = 1;
  1193. }
  1194. network *use = ((total%2==0) == (player==1)) ? net : net2;
  1195. mcts_tree *t = ((total%2==0) == (player==1)) ? tree1 : tree2;
  1196. t = run_mcts(t, use, board, two, player, mcts_iters, cpuct, 0);
  1197. move m = pick_move(t, temp, player);
  1198. if(((total%2==0) == (player==1))) tree1 = t;
  1199. else tree2 = t;
  1200. tree1 = move_mcts(tree1, m.row*19 + m.col);
  1201. tree2 = move_mcts(tree2, m.row*19 + m.col);
  1202. if(m.row == 19){
  1203. done = 1;
  1204. continue;
  1205. }
  1206. int row = m.row;
  1207. int col = m.col;
  1208. float *swap = two;
  1209. two = one;
  1210. one = swap;
  1211. if(player < 0) flip_board(board);
  1212. boards[count][0] = row;
  1213. boards[count][1] = col;
  1214. board_to_string(boards[count] + 2, board);
  1215. if(player < 0) flip_board(board);
  1216. ++count;
  1217. move_go(board, player, row, col);
  1218. copy_cpu(19*19*3, board, 1, one, 1);
  1219. player = -player;
  1220. }
  1221. }
  1222. void run_go(int argc, char **argv)
  1223. {
  1224. //boards_go();
  1225. if(argc < 4){
  1226. fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
  1227. return;
  1228. }
  1229. char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
  1230. int *gpus = 0;
  1231. int gpu = 0;
  1232. int ngpus = 0;
  1233. if(gpu_list){
  1234. printf("%s\n", gpu_list);
  1235. int len = strlen(gpu_list);
  1236. ngpus = 1;
  1237. int i;
  1238. for(i = 0; i < len; ++i){
  1239. if (gpu_list[i] == ',') ++ngpus;
  1240. }
  1241. gpus = calloc(ngpus, sizeof(int));
  1242. for(i = 0; i < ngpus; ++i){
  1243. gpus[i] = atoi(gpu_list);
  1244. gpu_list = strchr(gpu_list, ',')+1;
  1245. }
  1246. } else {
  1247. gpu = gpu_index;
  1248. gpus = &gpu;
  1249. ngpus = 1;
  1250. }
  1251. int clear = find_arg(argc, argv, "-clear");
  1252. char *cfg = argv[3];
  1253. char *weights = (argc > 4) ? argv[4] : 0;
  1254. char *c2 = (argc > 5) ? argv[5] : 0;
  1255. char *w2 = (argc > 6) ? argv[6] : 0;
  1256. int multi = find_arg(argc, argv, "-multi");
  1257. int anon = find_arg(argc, argv, "-anon");
  1258. int iters = find_int_arg(argc, argv, "-iters", 500);
  1259. int resign = find_int_arg(argc, argv, "-resign", 175);
  1260. float cpuct = find_float_arg(argc, argv, "-cpuct", 5);
  1261. float temp = find_float_arg(argc, argv, "-temp", .1);
  1262. float time = find_float_arg(argc, argv, "-time", 0);
  1263. if(0==strcmp(argv[2], "train")) train_go(cfg, weights, c2, gpus, ngpus, clear);
  1264. else if(0==strcmp(argv[2], "valid")) valid_go(cfg, weights, multi, c2);
  1265. else if(0==strcmp(argv[2], "self")) self_go(cfg, weights, c2, w2, multi);
  1266. else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
  1267. else if(0==strcmp(argv[2], "engine")) engine_go(cfg, weights, iters, time, temp, cpuct, anon, resign);
  1268. }