Source code for torchhd.datasets.emg_hand_gestures

#
# MIT License
#
# Copyright (c) 2023 Mike Heddes, Igor Nunes, Pere Vergés, Denis Kleyko, and Danny Abraham
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import os
import os.path
from typing import Callable, Optional, Tuple, List
import torch
from torch.utils import data
import pandas as pd
import math

from .utils import download_file_from_google_drive, unzip_file


class EMGHandGestures(data.Dataset):
    """EMG-based hand gestures dataset.

    Dataset from the paper `"Hyperdimensional Biosignal Processing: A Case Study for EMG-based Hand Gesture Recognition" <https://iis-people.ee.ethz.ch/~arahimi/papers/ICRC16.pdf>`_.

    Every sample is a block of ``window`` consecutive rows of a subject's
    feature file, cut from a run of rows that all carry the same label. Samples
    are stored as ``torch.long`` tensors of shape ``(window, 4)`` and targets as
    zero-based ``torch.long`` class indices (the label files are one-based).

    Args:
        root (string): Root directory of dataset where files are stored.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a
            ``torch.long`` tensor of shape ``(window, 4)`` and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        subjects (list[int], optional): The subject numbers from 0 til 4 to include.
            Defaults to [0, 1, 2, 3, 4].
        window (int, optional): The number of measurements to include in each sample.
            Defaults to 256.

    Raises:
        ValueError: If ``window`` is not positive.
        RuntimeError: If the dataset files are missing from ``root`` (after the
            optional download).
    """

    classes: List[str] = [
        "Closed hand",
        "Open hand",
        "Two-finger pinch",
        "Point index",
        "Rest position",
    ]

    # One feature file and one label file per subject; position i belongs to subject i.
    features_files: List[str] = [
        "COMPLETE_1.csv",
        "COMPLETE_2.csv",
        "COMPLETE_3.csv",
        "COMPLETE_4.csv",
        "COMPLETE_5.csv",
    ]

    labels_files: List[str] = [
        "LABEL_1.csv",
        "LABEL_2.csv",
        "LABEL_3.csv",
        "LABEL_4.csv",
        "LABEL_5.csv",
    ]

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
        subjects: Optional[List[int]] = None,
        window: int = 256,
    ):
        # Fail early with a clear message instead of a ZeroDivisionError or a
        # tensor-shape error deep inside the loader.
        if window <= 0:
            raise ValueError(f"window must be a positive integer, got {window}")

        root = os.path.join(root, "EMG_based_hand_gesture")
        root = os.path.expanduser(root)
        self.root = root
        os.makedirs(self.root, exist_ok=True)

        self.transform = transform
        self.target_transform = target_transform
        # Always store a private copy: a shared default list (or the caller's
        # own list) must not be mutated through one dataset instance.
        if subjects is None:
            self.subjects = list(range(len(self.features_files)))
        else:
            self.subjects = list(subjects)
        self.window = window

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. \nYou can use download=True to download it"
            )

        self._load_data()

    def __len__(self) -> int:
        """Return the number of windowed samples."""
        return self.data.size(0)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            index (int): Index

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (sample, target) where target is the
            index of the target class. Both are ``torch.long`` tensors unless the
            configured transforms return something else.
        """
        sample = self.data[index]
        label = self.targets[index]

        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return sample, label

    def _check_integrity(self) -> bool:
        """Return True when the root directory holds every feature and label file."""
        if not os.path.isdir(self.root):
            return False

        required = self.features_files + self.labels_files
        return all(
            os.path.isfile(os.path.join(self.root, name)) for name in required
        )

    def _load_data(self):
        """Read the selected subjects' CSV files and cut them into windows.

        For each subject the label column is split into runs of identical
        labels. A run that is strictly longer than ``self.window`` rows yields
        ``len(run) // self.window`` non-overlapping samples; shorter runs are
        skipped. The rows after the last label change are never used.
        The results are stored in ``self.data`` and ``self.targets``.
        """
        num_channels = 4  # number of columns expected in every feature file
        feature_chunks = []
        label_chunks = []

        for subject in self.subjects:
            complete = pd.read_csv(
                os.path.join(self.root, self.features_files[subject]), header=None
            )
            label = pd.read_csv(
                os.path.join(self.root, self.labels_files[subject]), header=None
            )

            # Convert each file once instead of once per window.
            gestures = label.values[:, 0]
            signal = torch.tensor(complete.values, dtype=torch.long)
            # Labels in the files start at 1; targets are zero-based.
            targets = torch.tensor(gestures, dtype=torch.long) - 1

            # Row indices at which the gesture differs from the previous row.
            boundaries = ((gestures[1:] != gestures[:-1]).nonzero()[0] + 1).tolist()

            start = 0
            for stop in boundaries:
                span = stop - start
                # Only runs strictly longer than the window produce samples.
                if span > self.window:
                    count = span // self.window
                    end = start + count * self.window
                    feature_chunks.append(
                        signal[start:end].reshape(count, self.window, num_channels)
                    )
                    # The target of a window is the label of its first row.
                    label_chunks.append(targets[start : end : self.window])
                start = stop

        if feature_chunks:
            # A single concatenation avoids re-copying the data for every window.
            self.data = torch.cat(feature_chunks)
            self.targets = torch.cat(label_chunks)
        else:
            self.data = torch.empty(0, self.window, num_channels, dtype=torch.long)
            self.targets = torch.empty(0, dtype=torch.long)

    def download(self):
        """Download the data if it doesn't exist already."""

        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        zip_file_path = os.path.join(self.root, "data.zip")
        download_file_from_google_drive(
            "1_9DYP1MwICtnMRm_F34xVvv1S2OhVnHH",  # Google Drive shared file ID
            zip_file_path,
        )

        unzip_file(zip_file_path, self.root)
        os.remove(zip_file_path)