1
0

darknet.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from ctypes import *
  2. import math
  3. import random
  4. def sample(probs):
  5. s = sum(probs)
  6. probs = [a/s for a in probs]
  7. r = random.uniform(0, 1)
  8. for i in range(len(probs)):
  9. r = r - probs[i]
  10. if r <= 0:
  11. return i
  12. return len(probs)-1
  13. def c_array(ctype, values):
  14. arr = (ctype*len(values))()
  15. arr[:] = values
  16. return arr
  17. class BOX(Structure):
  18. _fields_ = [("x", c_float),
  19. ("y", c_float),
  20. ("w", c_float),
  21. ("h", c_float)]
  22. class DETECTION(Structure):
  23. _fields_ = [("bbox", BOX),
  24. ("classes", c_int),
  25. ("prob", POINTER(c_float)),
  26. ("mask", POINTER(c_float)),
  27. ("objectness", c_float),
  28. ("sort_class", c_int)]
  29. class IMAGE(Structure):
  30. _fields_ = [("w", c_int),
  31. ("h", c_int),
  32. ("c", c_int),
  33. ("data", POINTER(c_float))]
  34. class METADATA(Structure):
  35. _fields_ = [("classes", c_int),
  36. ("names", POINTER(c_char_p))]
# --- Native library and ctypes prototypes ----------------------------------
# Alternative absolute path, kept for reference:
#lib = CDLL("/home/pjreddie/documents/darknet/libdarknet.so", RTLD_GLOBAL)
# Loaded by bare file name, so the dynamic loader's search rules decide which
# copy is used. RTLD_GLOBAL makes the library's symbols available for
# resolving libraries loaded afterwards.
lib = CDLL("libdarknet.so", RTLD_GLOBAL)

# Network handles are opaque pointers (c_void_p) throughout.
lib.network_width.argtypes = [c_void_p]
lib.network_width.restype = c_int
lib.network_height.argtypes = [c_void_p]
lib.network_height.restype = c_int

# Raw prediction on a float buffer; returns a pointer to float outputs.
# NOTE: attribute access on a CDLL is cached, so `network_predict` below is
# the same function object as `predict` and shares this restype.
predict = lib.network_predict
predict.argtypes = [c_void_p, POINTER(c_float)]
predict.restype = POINTER(c_float)

# No restype set: ctypes treats the return value as a C int.
set_gpu = lib.cuda_set_device
set_gpu.argtypes = [c_int]

# IMAGE is passed and returned by value.
make_image = lib.make_image
make_image.argtypes = [c_int, c_int, c_int]
make_image.restype = IMAGE

# detect() calls this as (net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)
# and reads the detection count back from pnum[0]. The returned array is
# released with free_detections.
get_network_boxes = lib.get_network_boxes
get_network_boxes.argtypes = [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int, POINTER(c_int)]
get_network_boxes.restype = POINTER(DETECTION)
make_network_boxes = lib.make_network_boxes
make_network_boxes.argtypes = [c_void_p]
make_network_boxes.restype = POINTER(DETECTION)

# Deallocation helpers (no restype set; return treated as C int).
free_detections = lib.free_detections
free_detections.argtypes = [POINTER(DETECTION), c_int]
free_ptrs = lib.free_ptrs
free_ptrs.argtypes = [POINTER(c_void_p), c_int]

# Redundant alias of `predict` (same cached object); re-sets identical argtypes.
network_predict = lib.network_predict
network_predict.argtypes = [c_void_p, POINTER(c_float)]
reset_rnn = lib.reset_rnn
reset_rnn.argtypes = [c_void_p]

# c_char_p parameters take bytes on Python 3 (str on Python 2); passing a
# Python 3 str raises ctypes.ArgumentError. The c_void_p result comes back to
# Python as an int (or None for NULL).
load_net = lib.load_network
load_net.argtypes = [c_char_p, c_char_p, c_int]
load_net.restype = c_void_p

# Non-maximum suppression over a detection array: (dets, count, classes, thresh).
do_nms_obj = lib.do_nms_obj
do_nms_obj.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]
do_nms_sort = lib.do_nms_sort
do_nms_sort.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]

free_image = lib.free_image
free_image.argtypes = [IMAGE]
letterbox_image = lib.letterbox_image
letterbox_image.argtypes = [IMAGE, c_int, c_int]
letterbox_image.restype = IMAGE

# Configured through `lib.get_metadata`, which is the same cached object as
# `load_meta`, so these settings apply to `load_meta` too.
load_meta = lib.get_metadata
lib.get_metadata.argtypes = [c_char_p]
lib.get_metadata.restype = METADATA

# detect() calls this as (path, 0, 0). Presumably the ints are a target
# width/height, with 0 meaning "keep original size"; TODO confirm.
load_image = lib.load_image_color
load_image.argtypes = [c_char_p, c_int, c_int]
load_image.restype = IMAGE

# Name suggests an in-place RGB<->BGR channel swap; not used in this module.
rgbgr_image = lib.rgbgr_image
rgbgr_image.argtypes = [IMAGE]

# Returns a pointer to float outputs. classify() reads meta.classes entries.
predict_image = lib.network_predict_image
predict_image.argtypes = [c_void_p, IMAGE]
predict_image.restype = POINTER(c_float)
  88. def classify(net, meta, im):
  89. out = predict_image(net, im)
  90. res = []
  91. for i in range(meta.classes):
  92. res.append((meta.names[i], out[i]))
  93. res = sorted(res, key=lambda x: -x[1])
  94. return res
  95. def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
  96. im = load_image(image, 0, 0)
  97. num = c_int(0)
  98. pnum = pointer(num)
  99. predict_image(net, im)
  100. dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)
  101. num = pnum[0]
  102. if (nms): do_nms_obj(dets, num, meta.classes, nms);
  103. res = []
  104. for j in range(num):
  105. for i in range(meta.classes):
  106. if dets[j].prob[i] > 0:
  107. b = dets[j].bbox
  108. res.append((meta.names[i], dets[j].prob[i], (b.x, b.y, b.w, b.h)))
  109. res = sorted(res, key=lambda x: -x[1])
  110. free_image(im)
  111. free_detections(dets, num)
  112. return res
  113. if __name__ == "__main__":
  114. #net = load_net("cfg/densenet201.cfg", "/home/pjreddie/trained/densenet201.weights", 0)
  115. #im = load_image("data/wolf.jpg", 0, 0)
  116. #meta = load_meta("cfg/imagenet1k.data")
  117. #r = classify(net, meta, im)
  118. #print r[:10]
  119. net = load_net("cfg/tiny-yolo.cfg", "tiny-yolo.weights", 0)
  120. meta = load_meta("cfg/coco.data")
  121. r = detect(net, meta, "data/dog.jpg")
  122. print r