1234567891011121314151617181920212223242526272829303132 |
import matplotlib
from utils.file.trans_methods import create_file_path

# Switch to the non-interactive "Agg" backend before pyplot is imported, so
# figures can be rendered and written to files without a display.
matplotlib.use('Agg')

# Global font configuration applied to every figure this module draws:
# SimHei is a CJK font, letting Chinese text in titles and labels render.
# axes.unicode_minus=False draws the minus sign as an ASCII hyphen instead of
# U+2212 (presumably because SimHei lacks that glyph; confirm if fonts change).
matplotlib.rcParams.update({
    'font.family': 'SimHei',
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

from matplotlib import pyplot as plt
def scatter(title, x_label, y_label, x_values, y_values, color=None, col_map=None, size=10,
            save_file_path=''):
    """Draw a scatter plot with pyplot and save it to an image file.

    Args:
        title: Figure title. When ``save_file_path`` is empty it also forms the
            output file name, ``title + '.png'``.
        x_label: Text of the x-axis label.
        y_label: Text of the y-axis label.
        x_values: X coordinates of the points, forwarded to ``plt.scatter``.
        y_values: Y coordinates of the points, forwarded to ``plt.scatter``.
        color: Optional value forwarded to ``plt.scatter`` as ``c``. When None,
            matplotlib's default marker colour is used and ``col_map`` is
            ignored.
        col_map: Optional mapping of legend label -> colour. Used only when
            ``color`` is given; every entry becomes a coloured legend patch.
            Defaults to None (no legend) rather than a shared mutable dict.
        size: Marker size, forwarded to ``plt.scatter`` as ``s``.
        save_file_path: Destination path of the image. When empty, the image is
            written to ``title + '.png'``.

    Returns:
        None. The figure is always closed before returning, including when
        drawing or saving raises.
    """
    if save_file_path:
        # Project helper; presumably prepares the target location (e.g. creates
        # missing parent directories) -- TODO confirm the meaning of ``True``.
        create_file_path(save_file_path, True)
    else:
        save_file_path = title + '.png'

    fig = plt.figure(figsize=(8, 6))
    try:
        plt.title(title, fontsize=16)
        plt.xlabel(x_label, fontsize=14)
        plt.ylabel(y_label, fontsize=14)
        if color is not None:
            plt.scatter(x_values, y_values, s=size, c=color)
            if col_map:
                # Proxy rectangles give the legend one coloured swatch per label,
                # independent of how the points themselves were coloured.
                patches = [plt.Rectangle((0, 0), 1, 1, fc=c) for c in col_map.values()]
                plt.legend(patches, list(col_map.keys()))
        else:
            plt.scatter(x_values, y_values, s=size)
        plt.savefig(save_file_path)
    finally:
        # pyplot keeps every open figure referenced until it is closed; close
        # this one even on failure so repeated calls cannot leak figures.
        plt.close(fig)
|