# Source code for UCTB.dataset.dataset

import os
import wget
import pickle
import tarfile


class DataSet(object):
    """An object storing basic data from a formatted pickle file.

    See also `Build your own datasets <https://di-chai.github.io/UCTB/static/tutorial.html>`_.

    Args:
        dataset (str): A string containing path of the dataset pickle file or a string of name of the dataset.
        city (str or ``None``): ``None`` if dataset is file path, or a string of name of the city.
            Default: ``None``
        data_dir (str or ``None``): The dataset directory. If set to ``None``, a ``data`` directory located
            one level above the directory containing this module is used. The directory is created if it
            does not exist. If ``dataset`` is file path, ``data_dir`` should be ``None`` too.
            Default: ``None``

    Attributes:
        data (dict): The data directly from the pickle file. ``data`` may have a ``data['contribute_data']``
            dict to store supplementary data.
        time_range (list): From ``data['TimeRange']`` in the format of [YYYY-MM-DD, YYYY-MM-DD] indicating
            the time range of the data.
        time_fitness (int): From ``data['TimeFitness']`` indicating how many minutes is a single time slot.
        node_traffic (np.ndarray): Data recording the main stream data of the nodes during the time range.
            From ``data['Node']['TrafficNode']`` with shape as [time_slot_num, node_num].
        node_monthly_interaction (np.ndarray): Data recording the monthly interaction of pairs of nodes.
            Its shape is [month_num, node_num, node_num]. It's from
            ``data['Node']['TrafficMonthlyInteraction']`` and is used to build interaction graph.
            It's an optional attribute and can be set as an empty list if interaction graph is not needed.
        node_station_info (dict): A dict storing the coordinates of nodes. It shall be formatted as
            {id (may be arbitrary): [id (when sorted, should be consistent with index of ``node_traffic``),
            latitude, longitude, other notes]}. It's from ``data['Node']['StationInfo']`` and is used to
            build distance graph. It's an optional attribute and can be set as an empty list if distance
            graph is not needed.
        grid_traffic: From ``data['Grid']['TrafficGrid']``.
        grid_lat_lng: From ``data['Grid']['GridLatLng']``.
        external_feature_weather: From ``data['ExternalFeature']['Weather']``.

    Raises:
        FileExistsError: If the pickle file is missing and cannot be obtained, i.e. ``dataset`` is a file
            path that does not exist, or downloading / extracting the ``{dataset}_{city}.tar.gz`` archive
            failed. The exception type is kept for backward compatibility; when an underlying error
            exists it is available as ``__cause__``.
        KeyError: If the loaded pickle lacks one of the keys listed under *Attributes*.
    """

    def __init__(self, dataset, city=None, data_dir=None):
        self.dataset = dataset
        self.city = city

        if data_dir is None:
            data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
        # exist_ok avoids the check-then-create race of a separate isdir() test.
        os.makedirs(data_dir, exist_ok=True)

        if self.city is not None:
            pkl_file_name = os.path.join(data_dir, '{}_{}.pkl'.format(self.dataset, self.city))
        else:
            pkl_file_name = self.dataset

        if not os.path.isfile(pkl_file_name):
            if self.city is None:
                # ``dataset`` was given as a file path; there is no remote archive that
                # corresponds to it, so fail right away instead of requesting a
                # meaningless "<path>_None.tar.gz" from the data repository.
                raise FileExistsError('Dataset file not found: {}'.format(pkl_file_name))
            self._fetch_and_extract(data_dir)

        # NOTE(security): pickle.load can execute arbitrary code while unpickling.
        # The file may have just been downloaded; only load archives from sources you trust.
        with open(pkl_file_name, 'rb') as f:
            self.data = pickle.load(f)

        self.time_range = self.data['TimeRange']
        self.time_fitness = self.data['TimeFitness']

        self.node_traffic = self.data['Node']['TrafficNode']
        self.node_monthly_interaction = self.data['Node']['TrafficMonthlyInteraction']
        self.node_station_info = self.data['Node']['StationInfo']

        self.grid_traffic = self.data['Grid']['TrafficGrid']
        self.grid_lat_lng = self.data['Grid']['GridLatLng']

        self.external_feature_weather = self.data['ExternalFeature']['Weather']

    def _fetch_and_extract(self, data_dir):
        """Unpack ``{dataset}_{city}.tar.gz`` into ``data_dir``, downloading it first if absent.

        The archive is deleted after a successful extraction.

        Args:
            data_dir (str): Existing directory holding (or receiving) the archive and its contents.

        Raises:
            FileExistsError: If the download or the extraction fails for any reason. The
                original error is chained as ``__cause__``.
        """
        tar_file_name = os.path.join(data_dir, '{}_{}.tar.gz'.format(self.dataset, self.city))
        try:
            if not os.path.isfile(tar_file_name):
                print('Downloading data into', data_dir)
                wget.download('https://github.com/Di-Chai/UCTB_Data/blob/master/%s_%s.tar.gz?raw=true' %
                              (self.dataset, self.city), tar_file_name)
                print('Download succeed')
            else:
                print('Found', tar_file_name)

            root = os.path.realpath(data_dir)
            # The context manager closes the archive even when extraction raises.
            with tarfile.open(tar_file_name, "r:gz") as tar:
                for member in tar.getmembers():
                    # Basic guard against path traversal: a member name such as
                    # "../x" or an absolute path must not land outside data_dir.
                    target = os.path.realpath(os.path.join(root, member.name))
                    if os.path.commonpath([root, target]) != root:
                        raise ValueError('Unsafe path in archive: {}'.format(member.name))
                    tar.extract(member, data_dir)
            os.remove(tar_file_name)
        except Exception as e:
            print(e)
            # Chain the cause so callers can see what actually went wrong.
            raise FileExistsError('Download Failed') from e