k_means_cluster.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. #!/usr/bin/python
  2. """
  3. Skeleton code for k-means clustering mini-project.
  4. """
  5. import pickle
  6. import numpy
  7. import matplotlib.pyplot as plt
  8. import sys
  9. sys.path.append("../tools/")
  10. from feature_format import featureFormat, targetFeatureSplit
  11. def Draw(pred, features, poi, mark_poi=False, name="image.png", f1_name="feature 1", f2_name="feature 2"):
  12. """ some plotting code designed to help you visualize your clusters """
  13. ### plot each cluster with a different color--add more colors for
  14. ### drawing more than five clusters
  15. colors = ["b", "c", "k", "m", "g"]
  16. for ii, pp in enumerate(pred):
  17. plt.scatter(features[ii][0], features[ii][1], color = colors[pred[ii]])
  18. ### if you like, place red stars over points that are POIs (just for funsies)
  19. if mark_poi:
  20. for ii, pp in enumerate(pred):
  21. if poi[ii]:
  22. plt.scatter(features[ii][0], features[ii][1], color="r", marker="*")
  23. plt.xlabel(f1_name)
  24. plt.ylabel(f2_name)
  25. plt.savefig(name)
  26. plt.show()
  27. ### load in the dict of dicts containing all the data on each person in the dataset
  28. data_dict = pickle.load( open("../final_project/final_project_dataset.pkl", "r") )
  29. ### there's an outlier--remove it!
  30. data_dict.pop("TOTAL", 0)
  31. ### the input features we want to use
  32. ### can be any key in the person-level dictionary (salary, director_fees, etc.)
  33. feature_1 = "salary"
  34. feature_2 = "exercised_stock_options"
  35. poi = "poi"
  36. features_list = [poi, feature_1, feature_2]
  37. data = featureFormat(data_dict, features_list )
  38. poi, finance_features = targetFeatureSplit( data )
  39. ### in the "clustering with 3 features" part of the mini-project,
  40. ### you'll want to change this line to
  41. ### for f1, f2, _ in finance_features:
  42. ### (as it's currently written, the line below assumes 2 features)
  43. for f1, f2 in finance_features:
  44. plt.scatter( f1, f2 )
  45. plt.show()
  46. ### cluster here; create predictions of the cluster labels
  47. ### for the data and store them to a list called pred
  48. ### rename the "name" parameter when you change the number of features
  49. ### so that the figure gets saved to a different file
  50. try:
  51. Draw(pred, finance_features, poi, mark_poi=False, name="clusters.pdf", f1_name=feature_1, f2_name=feature_2)
  52. except NameError:
  53. print "no predictions object named pred found, no clusters to plot"