← 返回首页
Google-mediapipe姿态跟踪框架简介
发表时间:2024-08-25 04:43:39
Google-mediapipe姿态跟踪框架简介

Mediapipe 是谷歌出品的一种开源框架,旨在为开发者提供一种简单而强大的工具,用于实现各种视觉和感知应用程序。它包括一系列预训练的机器学习模型和用于处理多媒体数据的工具,可以用于姿势估计、手部追踪、人脸检测与跟踪、面部标志、对象检测、图片分割和语言检测等任务。Mediapipe 是支持跨平台的,可以部署在手机端(Android, iOS), web, desktop, edge devices, IoT 等各种平台,编程语言也支持C++, Python, Java, Swift, Objective-C, Javascript等主流语言。

1.MediaPipe概述

MediaPipe 是一款由 Google Research 开发并开源的多媒体机器学习模型应用框架。 github 地址:https://github.com/google/mediapipe

MediaPipe目前支持的解决方案(Solution)及支持的平台如下图所示:

2.姿势估计

序号  部位          Pose Landmark
--------------------------------------------------
0   鼻子          PoseLandmark.NOSE
1   左眼(内侧)  PoseLandmark.LEFT_EYE_INNER
2   左眼          PoseLandmark.LEFT_EYE
3   左眼(外侧)  PoseLandmark.LEFT_EYE_OUTER
4   右眼(内侧)  PoseLandmark.RIGHT_EYE_INNER
5   右眼          PoseLandmark.RIGHT_EYE
6   右眼(外侧)  PoseLandmark.RIGHT_EYE_OUTER
7   左耳          PoseLandmark.LEFT_EAR
8   右耳          PoseLandmark.RIGHT_EAR
9   嘴巴(左侧)  PoseLandmark.MOUTH_LEFT
10  嘴巴(右侧)  PoseLandmark.MOUTH_RIGHT
11  左肩          PoseLandmark.LEFT_SHOULDER
12  右肩          PoseLandmark.RIGHT_SHOULDER
13  左肘          PoseLandmark.LEFT_ELBOW
14  右肘          PoseLandmark.RIGHT_ELBOW
15  左腕          PoseLandmark.LEFT_WRIST
16  右腕          PoseLandmark.RIGHT_WRIST
17  左小指         PoseLandmark.LEFT_PINKY
18  右小指         PoseLandmark.RIGHT_PINKY
19  左食指         PoseLandmark.LEFT_INDEX
20  右食指         PoseLandmark.RIGHT_INDEX
21  左拇指         PoseLandmark.LEFT_THUMB
22  右拇指         PoseLandmark.RIGHT_THUMB
23  左臀          PoseLandmark.LEFT_HIP
24  右臀          PoseLandmark.RIGHT_HIP
25  左膝          PoseLandmark.LEFT_KNEE
26  右膝          PoseLandmark.RIGHT_KNEE
27  左踝          PoseLandmark.LEFT_ANKLE
28  右踝          PoseLandmark.RIGHT_ANKLE
29  左脚跟         PoseLandmark.LEFT_HEEL
30  右脚跟         PoseLandmark.RIGHT_HEEL
31  左脚趾         PoseLandmark.LEFT_FOOT_INDEX
32  右脚趾         PoseLandmark.RIGHT_FOOT_INDEX

新版 solution API,旧版 API 并不能检测多个姿态,新版 API 可以实现多个姿态检测。

姿态估计实例。

1.检测单人姿态

安装依赖。

pip install mediapipe
import cv2
import numpy as np
import mediapipe as mp

def main():
    FILE_PATH = 'images/dance_girl_000.jpg'
    img = cv2.imread(FILE_PATH)
    mp_pose = mp.solutions.pose
    pose = mp_pose.Pose(static_image_mode=True,
                        min_detection_confidence=0.5, min_tracking_confidence=0.5)

    res = pose.process(img)
    img_copy = img.copy()

    if res.pose_landmarks is not None:
        mp_drawing = mp.solutions.drawing_utils
        # mp_drawing.draw_landmarks(
        #     img_copy, res.pose_landmarks, mp.solutions.pose.POSE_CONNECTIONS)
        mp_drawing.draw_landmarks(
            img_copy,
            res.pose_landmarks,
            mp_pose.POSE_CONNECTIONS,  # frozenset,定义了哪些关键点要连接
            mp_drawing.DrawingSpec(color=(255, 255, 255),  # 姿态关键点
                                   thickness=2,
                                   circle_radius=2),
            mp_drawing.DrawingSpec(color=(174, 139, 45),   # 连线颜色
                                   thickness=2,
                                   circle_radius=2),
        )

    cv2.imshow('MediaPipe Pose Estimation', img_copy)
    cv2.waitKey(0)

if __name__ == '__main__':
    main()

运行效果:

2.检测多人姿态

from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
import cv2
import numpy as np
import mediapipe as mp

mp_drawing = mp.solutions.drawing_utils
mp_pose = mp.solutions.pose

def draw_landmarks_on_image(rgb_image, detection_result):
    pose_landmarks_list = detection_result.pose_landmarks
    annotated_image = np.copy(rgb_image)

    # Loop through the detected poses to visualize.
    for idx in range(len(pose_landmarks_list)):
        pose_landmarks = pose_landmarks_list[idx]

        # Draw the pose landmarks.
        pose_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        pose_landmarks_proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in pose_landmarks
        ])
        solutions.drawing_utils.draw_landmarks(
            annotated_image,
            pose_landmarks_proto,
            solutions.pose.POSE_CONNECTIONS,
            solutions.drawing_styles.get_default_pose_landmarks_style())
    return annotated_image

#
'''
https://storage.googleapis.com/mediapipe-models/hand_landmarker/hand_landmarker/float16/latest/hand_landmarker.task.
'''
#
def newSolution():
    BaseOptions = mp.tasks.BaseOptions
    PoseLandmarker = mp.tasks.vision.PoseLandmarker
    PoseLandmarkerOptions = mp.tasks.vision.PoseLandmarkerOptions
    VisionRunningMode = mp.tasks.vision.RunningMode
    model_path = 'model/pose_landmarker_heavy.task'
    options = PoseLandmarkerOptions(
        base_options=BaseOptions(model_asset_path=model_path),
        running_mode=VisionRunningMode.IMAGE,
        num_poses=10)

    FILE_PATH = 'images/dance_girl_001.jpg'
    image = cv2.imread(FILE_PATH)
    img = mp.Image.create_from_file(FILE_PATH)
    with PoseLandmarker.create_from_options(options) as detector:
        res = detector.detect(img)
        image = draw_landmarks_on_image(image, res)

    cv2.imshow('MediaPipe Pose Estimation', image)
    cv2.waitKey(0)


if __name__ == '__main__':
    newSolution()

运行效果:

3.检测视频姿态

import cv2
import numpy as np
import mediapipe as mp


def video():
    # 读取摄像头
    # cap = cv2.VideoCapture(0)
    # 读取视频
    cap = cv2.VideoCapture('video/basketball_player_01.mp4')
    mp_pose = mp.solutions.pose
    pose = mp_pose.Pose(static_image_mode=False,
                        min_detection_confidence=0.5, min_tracking_confidence=0.5)
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
            # 摄像头
            # continue

        # 将 BGR 图像转换为 RGB
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # 进行姿势估计
        results = pose.process(rgb_frame)

        if results.pose_landmarks is not None:
            # 绘制关键点和连接线
            mp_drawing = mp.solutions.drawing_utils
            mp_drawing.draw_landmarks(
                frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)

        # 显示结果
        cv2.imshow('MediaPipe Pose Estimation', frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # 释放资源
    cap.release()
    cv2.destroyAllWindows()

if __name__ == '__main__':
    video()

运行效果:

3.手部追踪

手部追踪 常用API使用见下表:

手部跟踪实例。

1.手部跟踪(图像)

import cv2
import mediapipe as mp

mp_hands = mp.solutions.hands

def main():
    #cv2.namedWindow("MediaPipe Hand", cv2.WINDOW_NORMAL)
    cv2.namedWindow("MediaPipe Hand", cv2.WINDOW_AUTOSIZE)
    hands = mp_hands.Hands(static_image_mode=False, max_num_hands=2,
                           min_detection_confidence=0.5, min_tracking_confidence=0.5)
    img = cv2.imread('images/hand003.jpg')
    rgb_frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # 进行手部追踪
    results = hands.process(rgb_frame)

    if results.multi_hand_landmarks:
        # 绘制手部关键点和连接线
        for hand_landmarks in results.multi_hand_landmarks:
            mp_drawing = mp.solutions.drawing_utils
            mp_drawing.draw_landmarks(
                img, hand_landmarks, mp_hands.HAND_CONNECTIONS)

    # 显示结果
    cv2.imshow('MediaPipe Hand', img)
    cv2.waitKey(0)


if __name__ == '__main__':
    main()

运行效果:

2.手部跟踪(视频)

import cv2
import mediapipe as mp

mp_hands = mp.solutions.hands

def video():

    hands = mp_hands.Hands(static_image_mode=False, max_num_hands=2,
                           min_detection_confidence=0.4, min_tracking_confidence=0.4)

    # 读取视频
    cap = cv2.VideoCapture('video/handguesture01.mp4')

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # 将 BGR 图像转换为 RGB
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        height, width = rgb_frame.shape[:2]
        # 进行手部追踪
        results = hands.process(rgb_frame)

        if results.multi_hand_landmarks:
            # 绘制手部关键点和连接线
            for hand_landmarks in results.multi_hand_landmarks:
                mp_drawing = mp.solutions.drawing_utils
                mp_drawing.draw_landmarks(
                    frame, hand_landmarks, mp_hands.HAND_CONNECTIONS)
        resized_frame = cv2.resize(frame, (int(width / 2), int(height / 2)), interpolation=cv2.INTER_AREA)

        # 显示结果
        cv2.imshow('MediaPipe Hand Tracking', resized_frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # 释放资源
    cap.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    video()

运行效果: