matrix_fun.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # -*- coding:utf-8 -*-
  2. import numpy as np
  3. import face_mysql
  4. class matrix:
  5. def __init__(self):
  6. pass
  7. # 两个矩阵的欧式距离
  8. def EuclideanDistances(self, A, B):
  9. BT = B.transpose()
  10. # vecProd = A * BT
  11. vecProd = np.dot(A, BT)
  12. # print(vecProd)
  13. SqA = A ** 2
  14. # print(SqA)
  15. sumSqA = np.matrix(np.sum(SqA, axis=1))
  16. sumSqAEx = np.tile(sumSqA.transpose(), (1, vecProd.shape[1]))
  17. # print(sumSqAEx)
  18. SqB = B ** 2
  19. sumSqB = np.sum(SqB, axis=1)
  20. sumSqBEx = np.tile(sumSqB, (vecProd.shape[0], 1))
  21. SqED = sumSqBEx + sumSqAEx - 2 * vecProd
  22. SqED[SqED < 0] = 0.0
  23. ED = np.sqrt(SqED)
  24. return ED.transpose()
  25. #
  26. def get_socres(self, A, ugroup):
  27. # 设置每次处理的最大数据库记录数
  28. #如果数据库中记录太多时可分批进行处理
  29. maxlen = 128
  30. fmysql = face_mysql.face_mysql()
  31. results = np.array(fmysql.findall_facejson(ugroup))
  32. #如果没有数据库找到时 直接返回空的list
  33. if results.shape[0] == 0: return [],[],[]
  34. pic_scores_all = []
  35. #获取数据库中的入库时的图片名称 pic_names在数据库中存的是数组列索引4这个位置
  36. pic_names = results[:, 4]
  37. #获取入库时图片对象的uid pic_uid在数据库中存的是数组列索引2这个位置
  38. pic_uid = results[:, 2]
  39. for i in range(0, len(results), maxlen):
  40. pic_vectors = results[i:i + maxlen, 3]
  41. # 效率待优化,现在是每行处理
  42. pic_vectors = [[float(j) for j in i.split(',')] for i in pic_vectors]
  43. pic_socores = self.EuclideanDistances(A, np.array(pic_vectors))
  44. pic_socores_list = np.array(pic_socores).tolist()
  45. pic_scores_all.extend(pic_socores_list)
  46. pic_scores_all = np.array(pic_scores_all).transpose()
  47. # 获取距离最近的值
  48. # np.argsort() 返回排序后的索引
  49. pic_min_scores = np.amin(pic_scores_all, axis=1)
  50. pic_min_names = []
  51. pic_min_uid = []
  52. for i in range(0, len(pic_min_scores)):
  53. # 获取最小值的index
  54. index = np.where(pic_scores_all[i] == pic_min_scores[i])
  55. # print(int(index[0]))
  56. # 有多个符合条件的只取第一个
  57. pic_min_names.append(pic_names[index[0][0]])
  58. pic_min_uid.append(pic_uid[index[0][0]])
  59. # print(pic_min_names)
  60. return pic_min_scores.tolist(), pic_min_names, pic_min_uid