import matplotlib.pyplot as plt

from sklearn.datasets import fetch_openml
from sklearn.manifold import TSNE

mnist = fetch_openml("mnist_784")
x = mnist.data / 255.0
y = mnist.target.astype(int)

tsne = TSNE(n_components=2, n_jobs=-1)
X_tsne = tsne.fit_transform(x)

plt.figure()
for i in range(10):
plt.scatter(X_tsne[y == i, 0], X_tsne[y == i, 1], s=1, label=str(i))

plt.title("t-SNE Visualization of MNIST Dataset")
plt.xlabel("Component 1")
plt.ylabel("Component 2")
plt.legend(title="Digit")
plt.grid(True)
plt.show()

代码结果如下图:

最后修改:2025 年 04 月 17 日 12 : 37 AM
如果觉得我的文章对你有用,请随意赞赏