mnist數據集是一個很經典的手寫數據集,經常用於視覺神經網絡的教學,几乎是每個接觸AI的人都會嘗試識別的數據集。
以下代碼是我詢問了GPT之後,適當修改后得到的代碼,在此備忘。
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
# Step 1: 导入所需的库和模块
# Step 2: 加载和预处理数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
# Step 3: 构建模型
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])
# Step 4: 编译模型
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
# Step 5: 训练模型
model.fit(x_train, y_train, epochs=10, batch_size=32)
# Step 6: 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)
# Step 7: 使用模型进行预测
predictions = model.predict(x_test)
得到預測值之後,使用plt打印可視化的預測結果。請使用JupyterNotebook,否則無法顯示圖片。
predicted_labels = np.argmax(predictions, axis=1)
# 打印图像及其预测结果
plt.figure(figsize=(10, 10))
for i in range(25):
plt.subplot(5, 5, i+1)
plt.imshow(x_test[i], cmap='gray')
plt.title(f"Pre: {predicted_labels[i]}")
plt.axis('off')
plt.show()