首页 > Python资料 博客日记
tensorflow + pygame 手写数字识别的小游戏
2024-10-14 21:00:11Python资料围观56次
这篇文章介绍了tensorflow + pygame 手写数字识别的小游戏,分享给大家做个参考,收藏Python资料网收获更多编程知识
起因, 目的:
很久之前,一个客户的作业,我帮忙写的。
今天删项目,觉得比较简洁,发出来给大家看看。
效果图:
1. 训练模型的代码
import sys
import tensorflow as tf
# Use MNIST handwriting dataset
mnist = tf.keras.datasets.mnist
# Prepare data for training
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)
x_train = x_train.reshape(
x_train.shape[0], x_train.shape[1], x_train.shape[2], 1
)
x_test = x_test.reshape(
x_test.shape[0], x_test.shape[1], x_test.shape[2], 1
)
"""
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D((2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
"""
# Create a convolutional neural network
model = tf.keras.models.Sequential([
# 1. Convolutional layer. Learn 32 filters using a 3x3 kernel, activation function is relu, input shape (28,28,1)
tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
#2. Max-pooling layer, using 2x2 pool size
tf.keras.layers.MaxPooling2D((2, 2)),
#3. Flatten units
tf.keras.layers.Flatten(),
#4. Add a hidden layer with dropout,
tf.keras.layers.Dropout(0.2),
#5. Add an output layer with output units for all 10 digits, activation function is softmax
tf.keras.layers.Dense(10, activation='softmax')
])
# Train neural network
model.compile(
optimizer="adam",
loss="categorical_crossentropy",
metrics=["accuracy"]
)
model.fit(x_train, y_train, epochs=10)
# Evaluate neural network performance
model.evaluate(x_test, y_test, verbose=2)
# Save model to file
if len(sys.argv) == 2:
filename = sys.argv[1]
model.save(filename)
print(f"Model saved to {filename}.")
"""
Run this code: python handwriting.py model_1.pth
output:
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0413 - accuracy: 0.9873
Epoch 8/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0385 - accuracy: 0.9877
Epoch 9/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0338 - accuracy: 0.9898
Epoch 10/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0319 - accuracy: 0.9900
313/313 - 1s - loss: 0.0511 - accuracy: 0.9845 - 718ms/epoch - 2ms/step
Model saved to model_1.pth.
"""
2. 运行小游戏, 进行识别
从命令行运行:
python recognition.py model.h5
import numpy as np
import pygame
import sys
import tensorflow as tf
import time
"""
run this code:
python recognition.py model_1.pth
or,
python recognition.py model.h5
output:
"""
print("len(sys.argv): ", len(sys.argv))
# Check command-line arguments
if len(sys.argv) != 2:
print("Usage: python recognition.py model")
sys.exit()
model = tf.keras.models.load_model(sys.argv[1])
# Colors
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)
# Start pygame
pygame.init()
size = width, height = 600, 400
screen = pygame.display.set_mode(size)
# Fonts
OPEN_SANS = "assets/fonts/OpenSans-Regular.ttf"
smallFont = pygame.font.Font(OPEN_SANS, 20)
largeFont = pygame.font.Font(OPEN_SANS, 40)
ROWS, COLS = 28, 28
OFFSET = 20
CELL_SIZE = 10
handwriting = [[0] * COLS for _ in range(ROWS)]
classification = None
while True:
# Check if game quit
for event in pygame.event.get():
if event.type == pygame.QUIT:
sys.exit()
screen.fill(BLACK)
# Check for mouse press
click, _, _ = pygame.mouse.get_pressed()
if click == 1:
mouse = pygame.mouse.get_pos()
else:
mouse = None
# Draw each grid cell
cells = []
for i in range(ROWS):
row = []
for j in range(COLS):
rect = pygame.Rect(
OFFSET + j * CELL_SIZE,
OFFSET + i * CELL_SIZE,
CELL_SIZE, CELL_SIZE
)
# If cell has been written on, darken cell
if handwriting[i][j]:
channel = 255 - (handwriting[i][j] * 255)
pygame.draw.rect(screen, (channel, channel, channel), rect)
# Draw blank cell
else:
pygame.draw.rect(screen, WHITE, rect)
pygame.draw.rect(screen, BLACK, rect, 1)
# If writing on this cell, fill in current cell and neighbors
if mouse and rect.collidepoint(mouse):
handwriting[i][j] = 250 / 255
if i + 1 < ROWS:
handwriting[i + 1][j] = 220 / 255
if j + 1 < COLS:
handwriting[i][j + 1] = 220 / 255
if i + 1 < ROWS and j + 1 < COLS:
handwriting[i + 1][j + 1] = 190 / 255
# Reset button
resetButton = pygame.Rect(
30, OFFSET + ROWS * CELL_SIZE + 30,
100, 30
)
resetText = smallFont.render("Reset", True, BLACK)
resetTextRect = resetText.get_rect()
resetTextRect.center = resetButton.center
pygame.draw.rect(screen, WHITE, resetButton)
screen.blit(resetText, resetTextRect)
# Classify button
classifyButton = pygame.Rect(
150, OFFSET + ROWS * CELL_SIZE + 30,
100, 30
)
classifyText = smallFont.render("Classify", True, BLACK)
classifyTextRect = classifyText.get_rect()
classifyTextRect.center = classifyButton.center
pygame.draw.rect(screen, WHITE, classifyButton)
screen.blit(classifyText, classifyTextRect)
# Reset drawing
if mouse and resetButton.collidepoint(mouse):
handwriting = [[0] * COLS for _ in range(ROWS)]
classification = None
# Generate classification
if mouse and classifyButton.collidepoint(mouse):
classification = model.predict(
[np.array(handwriting).reshape(1, 28, 28, 1)]
).argmax()
# Show classification if one exists
if classification is not None:
classificationText = largeFont.render(str(classification), True, WHITE)
classificationRect = classificationText.get_rect()
grid_size = OFFSET * 2 + CELL_SIZE * COLS
classificationRect.center = (
grid_size + ((width - grid_size) / 2),
100
)
screen.blit(classificationText, classificationRect)
pygame.display.flip()
完整项目,我已经上传了。 0积分下载。
完整项目链接
https://download.csdn.net/download/waterHBO/89881853
老哥留步,支持一下。
版权声明:本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:jacktools123@163.com进行投诉反馈,一经查实,立即删除!
标签:
相关文章
最新发布
- 光流法结合深度学习神经网络的原理及应用(完整代码都有Python opencv)
- Python 图像处理进阶:特征提取与图像分类
- 大数据可视化分析-基于python的电影数据分析及可视化系统_9532dr50
- 【Python】入门(运算、输出、数据类型)
- 【Python】第一弹---解锁编程新世界:深入理解计算机基础与Python入门指南
- 华为OD机试E卷 --第k个排列 --24年OD统一考试(Java & JS & Python & C & C++)
- Python已安装包在import时报错未找到的解决方法
- 【Python】自动化神器PyAutoGUI —告别手动操作,一键模拟鼠标键盘,玩转微信及各种软件自动化
- Pycharm连接SQL Sever(详细教程)
- Python编程练习题及解析(49题)
点击排行
- 版本匹配指南:Numpy版本和Python版本的对应关系
- 版本匹配指南:PyTorch版本、torchvision 版本和Python版本的对应关系
- Python 可视化 web 神器:streamlit、Gradio、dash、nicegui;低代码 Python Web 框架:PyWebIO
- 相关性分析——Pearson相关系数+热力图(附data和Python完整代码)
- Anaconda版本和Python版本对应关系(持续更新...)
- Python与PyTorch的版本对应
- Windows上安装 Python 环境并配置环境变量 (超详细教程)
- Python pyinstaller打包exe最完整教程