import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.manifold import TSNE
# Embed the MNIST digit images into 2-D with t-SNE and scatter-plot each digit
# class in its own colour.
#
# NOTE: this runs t-SNE on the entire dataset, which can be very slow and
# memory-hungry; slice `x` and `y` (e.g. x[:10000], y[:10000]) for a quick look.

# Pin the dataset version so the download is deterministic, and request plain
# NumPy arrays rather than a pandas DataFrame/Series.
mnist = fetch_openml("mnist_784", version=1, as_frame=False)
x = mnist.data / 255.0  # scale pixel intensities into [0, 1]
y = mnist.target.astype(int)  # labels are not ints as fetched; convert for comparison

# A fixed random_state makes the embedding reproducible across runs.
tsne = TSNE(n_components=2, n_jobs=-1, random_state=42)
X_tsne = tsne.fit_transform(x)

plt.figure()
for i in range(10):
    mask = y == i  # rows whose label is digit i
    plt.scatter(X_tsne[mask, 0], X_tsne[mask, 1], s=1, label=str(i))
plt.title("t-SNE Visualization of MNIST Dataset")
plt.xlabel("Component 1")
plt.ylabel("Component 2")
# The s=1 plot markers would be nearly invisible in the legend; enlarge them there only.
plt.legend(title="Digit", markerscale=10)
plt.grid(True)
plt.show()
代码运行结果如下图所示：