12345678910111213141516171819202122232425262728293031323334353637 |
- from darknet import *
def predict_tactic(net, s):
    """Sample one tactic from the network, seeded with the text ``s``.

    Each character is presented to ``predict`` as a 256-entry vector with a
    single 1 at index ``ord(character)``.  All but the last character of
    ``s`` are fed only for their effect on the network (their outputs are
    discarded); the last character starts the sampling loop, in which every
    sampled character is fed back in as the next input.

    Args:
        net: Network handle passed through to ``predict`` (as returned by
            ``load_net`` elsewhere in this file).
        s: Seed text.  An empty seed is replaced by a single newline.
            Every character must satisfy ``ord(c) < 256``.

    Returns:
        A ``(tactic, log_prob)`` tuple.  ``tactic`` is the sampled text up
        to and including the first ``'.'``.  ``log_prob`` is the sum of
        ``math.log`` of the network output at each sampled index, i.e. the
        log-score of exactly the characters in ``tactic``.

    Note:
        There is no length limit: the loop runs until a ``'.'`` is sampled.
        ``math.log`` raises ``ValueError`` if ``sample`` ever returns an
        index whose output value is 0.
    """
    if not s:
        s = '\n'

    # Reusable input buffer; exactly one slot is set to 1 around each call
    # and cleared again afterwards.
    one_hot = c_array(c_float, [0.0] * 256)

    # Feed the seed prefix; only the side effect on the network is wanted.
    # NOTE(review): presumably this advances the recurrent state that
    # reset_rnn() clears -- confirm against the darknet bindings.
    for ch in s[:-1]:
        one_hot[ord(ch)] = 1
        predict(net, one_hot)
        one_hot[ord(ch)] = 0

    log_prob = 0
    sampled = []
    ch = s[-1]
    while True:
        one_hot[ord(ch)] = 1
        out = predict(net, one_hot)
        one_hot[ord(ch)] = 0

        # Copy the 256 outputs into a plain list before sampling from them.
        dist = [out[i] for i in range(256)]
        idx = sample(dist)
        ch = chr(idx)

        log_prob += math.log(dist[idx])
        sampled.append(ch)

        # Stop as soon as the terminator itself has been sampled, so the
        # score covers exactly the returned characters and no extra
        # character is drawn (and scored) after the '.'.
        if ch == '.':
            break

    return (''.join(sampled), log_prob)
def predict_tactics(net, s, n):
    """Draw ``n`` independent tactic samples for seed ``s``, best score first.

    ``reset_rnn(net)`` is called before every sample, so each call to
    ``predict_tactic`` starts from a freshly reset network.

    Args:
        net: Network handle passed to ``reset_rnn`` and ``predict_tactic``.
        s: Seed text forwarded unchanged to ``predict_tactic``.
        n: Number of samples to draw.

    Returns:
        A list of ``n`` ``(tactic, log_prob)`` tuples ordered by descending
        ``log_prob``.
    """
    def _fresh_sample():
        # Reset first so this sample does not depend on the previous one.
        reset_rnn(net)
        return predict_tactic(net, s)

    ranked = [_fresh_sample() for _ in range(n)]
    # Negated score as key: highest log-probability comes first.
    ranked.sort(key=lambda pair: -pair[1])
    return ranked
# Driver: load the network from its config and weights files, draw ten
# tactic samples for the seed "+++++\n", and print them best score first.
# NOTE(review): the meaning of the third load_net argument (0) is not
# visible in this file -- confirm against the darknet bindings.
net = load_net(
    "cfg/coq.test.cfg",
    "/home/pjreddie/backup/coq.backup",
    0
)
t = predict_tactics(net, "+++++\n", 10)
print(t)
|