# Source code for torchhd.datasets.beijing_air_quality

#
# MIT License
#
# Copyright (c) 2023 Mike Heddes, Igor Nunes, Pere Vergés, Denis Kleyko, and Danny Abraham
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import os
import os.path
from typing import Callable, List, NamedTuple, Optional, Tuple

import pandas as pd
import torch
from torch.utils import data

from .utils import download_file, unzip_file


class BeijingAirQualityDataSample(NamedTuple):
    """A single sample, split into its categorical and continuous parts."""

    # Categorical features of the sample, held as integer (long) values.
    categorical: torch.LongTensor
    # Continuous features of the sample, held as floating point values.
    continuous: torch.FloatTensor


class BeijingAirQuality(data.Dataset):
    """`Beijing Multi-Site Air-Quality <https://archive.ics.uci.edu/ml/datasets/Beijing+Multi-Site+Air-Quality+Data>`_ dataset.

    .. list-table::
        :widths: 10 10 10 10
        :align: center
        :header-rows: 1

        * - Instances
          - Attributes
          - Task
          - Area
        * - 420768
          - 18
          - Regression
          - Physical

    .. warning::
        The data contains NaN values that need to be taken into account.

    Args:
        root (string): Root directory of dataset where directory
            ``beijing-air-quality`` exists or will be saved to if download is set to True.
        transform (callable, optional): A function/transform that takes in a sample
            (the categorical and continuous tensors of one row) and returns a
            transformed version.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the expected csv files are not present in the dataset
            directory after the optional download step.

    """

    # Category name lookups, filled in by ``_load_data``. The integer stored in
    # the categorical tensor is the position of the name in these tuples.
    wind_directions: Tuple[str, ...]
    stations: Tuple[str, ...]

    # Column order of the second dimension of ``categorical_data``.
    categorical_columns: List[str] = [
        "year",
        "month",
        "day",
        "hour",
        "wind_direction",
        "station",
    ]

    # Column order of the second dimension of ``continuous_data``.
    continuous_columns: List[str] = [
        "PM2.5",
        "PM10",
        "SO2",
        "NO2",
        "CO",
        "O3",
        "temperature",
        "pressure",
        "dew_point_temperature",
        "precipitation",
        "wind_speed",
    ]

    # Number of csv files the dataset directory is expected to contain.
    _NUM_CSV_FILES: int = 12

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        download: bool = False,
    ):
        root = os.path.join(root, "beijing-air-quality")
        root = os.path.expanduser(root)
        self.root = root
        os.makedirs(self.root, exist_ok=True)

        self.transform = transform

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        self._load_data()

    def __len__(self) -> int:
        """Return the number of samples (rows) in the dataset."""
        return self.categorical_data.size(0)

    def __getitem__(self, index: int) -> "BeijingAirQualityDataSample":
        """
        Args:
            index (int): Index

        Returns:
            BeijingAirQualityDataSample: Indexed Sample
        """
        sample = BeijingAirQualityDataSample(
            self.categorical_data[index], self.continuous_data[index]
        )

        # Compare against None rather than relying on truthiness: a callable
        # that defines ``__len__`` (e.g. an empty container module) is falsy
        # but must still be applied.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    def _csv_files(self) -> List[str]:
        """Return the csv file names found in the dataset directory, sorted by name."""
        # ``os.listdir`` returns entries in arbitrary order; sorting keeps the
        # row order of the loaded dataset identical across runs and platforms.
        return sorted(name for name in os.listdir(self.root) if name.endswith(".csv"))

    def _check_integrity(self) -> bool:
        """Return True when the dataset directory holds the expected csv files."""
        if not os.path.isdir(self.root):
            return False

        # Check if the root directory contains the expected number of csv files
        # TODO: Add more specific checks like an MD5 checksum
        return len(self._csv_files()) == self._NUM_CSV_FILES

    def _load_data(self):
        """Read all csv files and build the categorical and continuous tensors.

        Sets ``wind_directions``, ``stations``, ``categorical_data`` and
        ``continuous_data`` on the instance.
        """
        tables = []
        for filename in self._csv_files():
            table = pd.read_csv(os.path.join(self.root, filename))
            # The "No" column does not provide any meaningful information
            table = table.drop(columns=["No"])
            tables.append(table)

        frame = pd.concat(tables, ignore_index=True)

        # Rename columns for accessibility
        frame = frame.rename(
            columns={
                "wd": "wind_direction",
                "TEMP": "temperature",
                "PRES": "pressure",
                "DEWP": "dew_point_temperature",
                "RAIN": "precipitation",
                "WSPM": "wind_speed",
            }
        )

        # Change null values to Not Available string "N/A"
        frame["wind_direction"] = frame["wind_direction"].fillna("N/A")

        # Map the wind directions to a category identifier (int)
        # Save the mapping for user referencing
        self.wind_directions = tuple(sorted(set(frame["wind_direction"])))
        frame["wind_direction"] = self._encode(
            frame["wind_direction"], self.wind_directions
        )

        # Map the stations to a category identifier (int)
        # Save the mapping for user referencing
        self.stations = tuple(sorted(set(frame["station"])))
        frame["station"] = self._encode(frame["station"], self.stations)

        categorical = frame[self.categorical_columns]
        self.categorical_data = torch.tensor(categorical.values, dtype=torch.long)

        continuous = frame[self.continuous_columns]
        self.continuous_data = torch.tensor(continuous.values, dtype=torch.float)

    @staticmethod
    def _encode(column, categories):
        """Replace each value of ``column`` by its position in ``categories``."""
        # A dictionary lookup avoids a linear ``tuple.index`` scan per row.
        lookup = {name: position for position, name in enumerate(categories)}
        return column.map(lookup).astype("int64")

    def download(self):
        """Download the data if it doesn't exist already."""

        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        zip_file_path = os.path.join(self.root, "data.zip")
        download_file(
            "https://archive.ics.uci.edu/ml/machine-learning-databases/00501/PRSA2017_Data_20130301-20170228.zip",
            zip_file_path,
        )

        unzip_file(zip_file_path, self.root)
        os.remove(zip_file_path)

        # The archive extracts into a sub-directory; move its files up into root.
        source_dir = os.path.join(self.root, "PRSA_Data_20130301-20170228")
        for filename in os.listdir(source_dir):
            # ``os.replace`` overwrites leftovers of an earlier, incomplete
            # download on every platform, where ``os.rename`` may raise.
            os.replace(
                os.path.join(source_dir, filename), os.path.join(self.root, filename)
            )

        os.rmdir(source_dir)