杯子茶室

关注有趣的事物

使用Tensorflow庫中的mnist手寫數字數據集訓練自己的第一個視覺模型

网络 0 评 66 度

mnist數據集是一個很經典的手寫數據集,經常用於視覺神經網絡的教學,几乎是每個接觸AI的人都會嘗試識別的數據集。

以下代碼是我詢問了GPT之後,適當修改后得到的代碼,在此備忘。

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

# Step 1: 导入所需的库和模块

# Step 2: 加载和预处理数据集

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

x_train = x_train / 255.0

x_test = x_test / 255.0

y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# Step 3: 构建模型

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

# Step 4: 编译模型

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Step 5: 训练模型

model.fit(x_train, y_train, epochs=10, batch_size=32)

# Step 6: 评估模型

test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)

# Step 7: 使用模型进行预测

predictions = model.predict(x_test)

得到預測值之後,使用plt打印可視化的預測結果。請使用JupyterNotebook,否則無法顯示圖片。

predicted_labels = np.argmax(predictions, axis=1)

# 打印图像及其预测结果
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i+1)
    plt.imshow(x_test[i], cmap='gray')
    plt.title(f"Pre: {predicted_labels[i]}")
    plt.axis('off')
plt.show()
解决Manjaro在睡眠(or休眠)唤醒之后触摸板失灵的解决方法
发表评论
撰写评论