Mediapipe简介
Mediapipe 是由 Google Research 开发的一款开源框架,旨在帮助开发者轻松地构建、测试和部署复杂的多模态、多任务的机器学习模型。它特别擅长于实时处理和分析音频、视频等多媒体数据。以下是 Mediapipe 的一些关键特点和组件:
关键特点
多平台支持:Mediapipe 支持在多个平台上运行,包括桌面、移动设备和网页。这使得开发者可以轻松地将模型部署到不同的平台上。
高效的实时处理:Mediapipe 具有高度优化的性能,能够在资源受限的设备上进行实时处理。这使其特别适合于移动设备和嵌入式系统。
模块化设计:Mediapipe 使用图表(graph)来组织和连接不同的处理模块。这种设计使得开发者可以灵活地组合和复用不同的处理组件。
丰富的预构建解决方案:Mediapipe 提供了许多预构建的解决方案,如人脸检测、手部追踪、姿态估计等,开发者可以直接使用这些解决方案来快速构建应用。
主要组件
图表(Graph):Mediapipe 的核心是其图表结构,图表定义了数据流和处理模块的连接方式。每个图表由一系列节点(nodes)和边(edges)组成,节点表示具体的处理模块,边表示数据在节点之间的流动。
节点(Nodes):节点是图表的基本单元,表示具体的处理操作。Mediapipe 提供了许多内置的节点,如数据输入输出节点、图像处理节点、机器学习推理节点等。
数据包(Packets):数据包是图表中传输的数据单元,节点之间通过发送和接收数据包来通信。数据包可以包含各种类型的数据,如图像帧、音频信号、检测结果等。
计算机视觉解决方案:Mediapipe 提供了许多预构建的计算机视觉解决方案,这些解决方案已经高度优化,能够在实时应用中使用。常见的解决方案包括人脸检测、手部追踪、姿态估计、对象检测等。
常见使用场景
姿态估计(Pose Estimation):Mediapipe 可以实时检测和追踪人体的关键点(如肩膀、肘部、膝盖等),并估计人体的姿态。这对于体育训练、动作捕捉、增强现实等应用非常有用。
手部追踪(Hand Tracking):Mediapipe 能够检测和追踪手部的关键点,提供手势识别和手部动作分析的能力。这在手势控制、虚拟现实、手写输入等应用中有广泛的应用。
人脸检测(Face Detection):Mediapipe 提供了高效的人脸检测和关键点追踪功能,可以用于面部识别、表情分析、虚拟化妆等场景。
对象检测(Object Detection):Mediapipe 还提供了实时的对象检测解决方案,可以用于监控、无人驾驶、智能家居等领域。
示例代码
以下是一个使用 Mediapipe 进行姿态估计的简单示例:
import cv2
import mediapipe as mp
# Initialize mediapipe pose class.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose()
# Initialize mediapipe drawing class, useful for annotation.
mp_drawing = mp.solutions.drawing_utils
# Load the video file or start webcam capture.
cap = cv2.VideoCapture(0) # Use 0 for webcam, or provide video file path
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
# Convert the BGR image to RGB.
image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Process the image and detect the pose.
result = pose.process(image_rgb)
# Draw the pose annotation on the image.
if result.pose_landmarks:
mp_drawing.draw_landmarks(frame, result.pose_landmarks, mp_pose.POSE_CONNECTIONS)
# Display the frame with pose landmarks.
cv2.imshow('Pose Estimation', frame)
# Break the loop if 'q' is pressed.
if cv2.waitKey(10) & 0xFF == ord('q'):
break
# Release the video capture object and close display window.
cap.release()
cv2.destroyAllWindows()
这段代码使用 Mediapipe 的姿态估计功能,读取视频流并实时绘制人体的关键点。你可以使用摄像头实时捕捉人体姿态,也可以处理预录制的视频文件。
实例1-读取视频流并进行骨骼点绘制:
import cv2
import mediapipe as mp
# Initialize mediapipe pose class.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose()
# Initialize mediapipe drawing class, useful for annotation.
mp_drawing = mp.solutions.drawing_utils
# Load the video file.
cap = cv2.VideoCapture('D:/basketball.mp4')
# Check if the video is opened successfully.
if not cap.isOpened():
print("Error: Could not open video.")
exit()
while cap.isOpened():
ret, frame = cap.read()
if not ret:
print("Reached the end of the video.")
break
# Convert the BGR image to RGB.
image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Process the image and detect the pose.
result = pose.process(image_rgb)
# Draw the pose annotation on the image.
if result.pose_landmarks:
mp_drawing.draw_landmarks(frame, result.pose_landmarks, mp_pose.POSE_CONNECTIONS)
# Display the frame with pose landmarks.
cv2.imshow('Pose Estimation', frame)
# Break the loop if 'q' is pressed.
if cv2.waitKey(10) & 0xFF == ord('q'):
break
# Release the video capture object and close display window.
cap.release()
cv2.destroyAllWindows()
效果如下:
实例2-读取视频流中姿态估计与3D绘制:
代码如下:
import cv2
import mediapipe as mp
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from matplotlib.animation import FuncAnimation
# Initialize mediapipe pose class.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose()
# Initialize mediapipe drawing class, useful for annotation.
mp_drawing = mp.solutions.drawing_utils
# Load the video file.
cap = cv2.VideoCapture('D:/basketball.mp4')
# Check if the video is opened successfully.
if not cap.isOpened():
print("Error: Could not open video.")
exit()
fig = plt.figure(figsize=(10, 5))
ax2d = fig.add_subplot(121)
ax3d = fig.add_subplot(122, projection='3d')
def update(frame_number):
ret, frame = cap.read()
if not ret:
print("Reached the end of the video.")
return
# Convert the BGR image to RGB.
image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Process the image and detect the pose.
result = pose.process(image_rgb)
# Clear the previous plots
ax2d.clear()
ax3d.clear()
# Draw the pose annotation on the image.
if result.pose_landmarks:
mp_drawing.draw_landmarks(frame, result.pose_landmarks, mp_pose.POSE_CONNECTIONS)
# Extract the landmark points.
landmarks = result.pose_landmarks.landmark
xs = [landmark.x for landmark in landmarks]
ys = [landmark.y for landmark in landmarks]
zs = [landmark.z for landmark in landmarks]
# Plot 2D image
ax2d.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
ax2d.set_title('Pose Estimation')
ax2d.axis('off')
# Plot 3D landmarks
ax3d.scatter(xs, ys, zs, c='blue', marker='o')
ax3d.set_xlim([0, 1])
ax3d.set_ylim([0, 1])
ax3d.set_zlim([-1, 1])
ax3d.set_xlabel('X')
ax3d.set_ylabel('Y')
ax3d.set_zlabel('Z')
ax3d.set_title('3D Pose Landmarks')
ani = FuncAnimation(fig, update, interval=10)
plt.show()
cap.release()
cv2.destroyAllWindows()
效果如下:
为了将三维骨骼点连接起来,可以使用 mpl_toolkits.mplot3d.art3d.Line3DCollection
来绘制骨骼连接。你需要定义这些连接的点对,并在三维图中使用它们来绘制线条。以下是更新后的代码:
import cv2
import mediapipe as mp
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Line3DCollection
import numpy as np
from matplotlib.animation import FuncAnimation
# Initialize mediapipe pose class.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose()
# Initialize mediapipe drawing class, useful for annotation.
mp_drawing = mp.solutions.drawing_utils
# Load the video file.
cap = cv2.VideoCapture('D:/basketball.mp4')
# Check if the video is opened successfully.
if not cap.isOpened():
print("Error: Could not open video.")
exit()
fig = plt.figure(figsize=(10, 5))
ax2d = fig.add_subplot(121)
ax3d = fig.add_subplot(122, projection='3d')
def update(frame_number):
ret, frame = cap.read()
if not ret:
print("Reached the end of the video.")
return
# Convert the BGR image to RGB.
image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Process the image and detect the pose.
result = pose.process(image_rgb)
# Clear the previous plots
ax2d.clear()
ax3d.clear()
# Draw the pose annotation on the image.
if result.pose_landmarks:
mp_drawing.draw_landmarks(frame, result.pose_landmarks, mp_pose.POSE_CONNECTIONS)
# Extract the landmark points.
landmarks = result.pose_landmarks.landmark
xs = [landmark.x for landmark in landmarks]
ys = [landmark.y for landmark in landmarks]
zs = [landmark.z for landmark in landmarks]
# Define the connections between landmarks
connections = [
(0, 1), (1, 2), (2, 3), (3, 7), (0, 4), (4, 5), (5, 6), (6, 8),
(9, 10), (11, 12), (11, 13), (13, 15), (15, 17), (15, 19), (15, 21),
(17, 19), (12, 14), (14, 16), (16, 18), (16, 20), (16, 22), (18, 20),
(11, 23), (12, 24), (23, 24), (23, 25), (24, 26), (25, 27), (26, 28),
(27, 29), (28, 30), (29, 31), (30, 32)
]
# Create a list of 3D lines
lines = [[(xs[start], ys[start], zs[start]), (xs[end], ys[end], zs[end])] for start, end in connections]
# Plot 2D image
ax2d.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
ax2d.set_title('Pose Estimation')
ax2d.axis('off')
# Plot 3D landmarks and connections
ax3d.scatter(xs, ys, zs, c='blue', marker='o')
ax3d.add_collection3d(Line3DCollection(lines, colors='blue', linewidths=2))
ax3d.set_xlim([0, 1])
ax3d.set_ylim([0, 1])
ax3d.set_zlim([-1, 1])
ax3d.set_xlabel('X')
ax3d.set_ylabel('Y')
ax3d.set_zlabel('Z')
ax3d.set_title('3D Pose Landmarks')
ani = FuncAnimation(fig, update, interval=10)
plt.show()
cap.release()
cv2.destroyAllWindows()
在这个代码中,我们定义了 connections
列表,它包含了骨骼点之间的连接对。然后我们创建了一个 lines
列表,用于存储这些连接的三维线段,并使用 ax3d.add_collection3d(Line3DCollection(lines, colors='blue', linewidths=2))
方法将这些线段添加到三维图中。
运行这个脚本后,三维图中不仅会显示骨骼点,还会将这些点连起来,形成完整的骨骼结构。
效果如下:
上面的代码看似三维图的骨骼是倒立的,你可以调整三维图的坐标显示,以使得骨骼结构显示为正常的人体姿态。可以通过设置三维图的坐标轴范围和方向来调整显示效果。以下是修改后的代码,调整了坐标轴的范围和方向,以使骨骼结构正常显示:
import cv2
import mediapipe as mp
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Line3DCollection
import numpy as np
from matplotlib.animation import FuncAnimation
# Initialize mediapipe pose class.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose()
# Initialize mediapipe drawing class, useful for annotation.
mp_drawing = mp.solutions.drawing_utils
# Load the video file.
cap = cv2.VideoCapture('D:/basketball.mp4')
# Check if the video is opened successfully.
if not cap.isOpened():
print("Error: Could not open video.")
exit()
fig = plt.figure(figsize=(10, 5))
ax2d = fig.add_subplot(121)
ax3d = fig.add_subplot(122, projection='3d')
def update(frame_number):
ret, frame = cap.read()
if not ret:
print("Reached the end of the video.")
return
# Convert the BGR image to RGB.
image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Process the image and detect the pose.
result = pose.process(image_rgb)
# Clear the previous plots
ax2d.clear()
ax3d.clear()
# Draw the pose annotation on the image.
if result.pose_landmarks:
mp_drawing.draw_landmarks(frame, result.pose_landmarks, mp_pose.POSE_CONNECTIONS)
# Extract the landmark points.
landmarks = result.pose_landmarks.landmark
xs = [landmark.x for landmark in landmarks]
ys = [landmark.y for landmark in landmarks]
zs = [-landmark.z for landmark in landmarks] # Negate the z-axis for better visualization
# Define the connections between landmarks
connections = [
(0, 1), (1, 2), (2, 3), (3, 7), (0, 4), (4, 5), (5, 6), (6, 8),
(9, 10), (11, 12), (11, 13), (13, 15), (15, 17), (15, 19), (15, 21),
(17, 19), (12, 14), (14, 16), (16, 18), (16, 20), (16, 22), (18, 20),
(11, 23), (12, 24), (23, 24), (23, 25), (24, 26), (25, 27), (26, 28),
(27, 29), (28, 30), (29, 31), (30, 32)
]
# Create a list of 3D lines
lines = [[(xs[start], ys[start], zs[start]), (xs[end], ys[end], zs[end])] for start, end in connections]
# Plot 2D image
ax2d.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
ax2d.set_title('Pose Estimation')
ax2d.axis('off')
# Plot 3D landmarks and connections
ax3d.scatter(xs, ys, zs, c='blue', marker='o')
ax3d.add_collection3d(Line3DCollection(lines, colors='blue', linewidths=2))
ax3d.set_xlim([0, 1])
ax3d.set_ylim([1, 0]) # Flip the y-axis for better visualization
ax3d.set_zlim([1, -1])
ax3d.set_xlabel('X')
ax3d.set_ylabel('Y')
ax3d.set_zlabel('Z')
ax3d.set_title('3D Pose Landmarks')
ani = FuncAnimation(fig, update, interval=10)
plt.show()
cap.release()
cv2.destroyAllWindows()
在这个代码中:
- 通过取反
zs
坐标 (zs = [-landmark.z for landmark in landmarks]
),使得骨骼点的 Z 轴方向与预期一致。 - 通过设置
ax3d.set_ylim([1, 0])
来翻转 Y 轴的方向,以便更符合常见的视觉习惯。
运行这个脚本后,三维图中的骨骼结构应会显示为正常的人体姿态。
显示效果如下: