1
0

proverbot.py 900 B

12345678910111213141516171819202122232425262728293031323334353637
  1. from darknet import *
  2. def predict_tactic(net, s):
  3. prob = 0
  4. d = c_array(c_float, [0.0]*256)
  5. tac = ''
  6. if not len(s):
  7. s = '\n'
  8. for c in s[:-1]:
  9. d[ord(c)] = 1
  10. pred = predict(net, d)
  11. d[ord(c)] = 0
  12. c = s[-1]
  13. while 1:
  14. d[ord(c)] = 1
  15. pred = predict(net, d)
  16. d[ord(c)] = 0
  17. pred = [pred[i] for i in range(256)]
  18. ind = sample(pred)
  19. c = chr(ind)
  20. prob += math.log(pred[ind])
  21. if len(tac) and tac[-1] == '.':
  22. break
  23. tac = tac + c
  24. return (tac, prob)
  25. def predict_tactics(net, s, n):
  26. tacs = []
  27. for i in range(n):
  28. reset_rnn(net)
  29. tacs.append(predict_tactic(net, s))
  30. tacs = sorted(tacs, key=lambda x: -x[1])
  31. return tacs
  32. net = load_net("cfg/coq.test.cfg", "/home/pjreddie/backup/coq.backup", 0)
  33. t = predict_tactics(net, "+++++\n", 10)
  34. print t