""" Manages factor information for a tabular file. """
import pandas as pd
from hed.errors.exceptions import HedFileError
class HedTypeFactors:
    """ Holds index of positions for a variable type for one tabular file.

    Attributes:
        type_value (str): The value of the type summarized by this class.
        number_elements (int): Number of elements (rows) in the data column.
        type_tag (str): Lowercased type tag.
        levels (dict): Maps each level name to a dict keyed by the row positions at which it occurs.
        direct_indices (dict): Dict keyed by the row positions that reference the type value directly.

    """

    ALLOWED_ENCODINGS = ("categorical", "one-hot")

    def __init__(self, type_tag, type_value, number_elements):
        """ Constructor for HedTypeFactors.

        Parameters:
            type_tag (str): Lowercase string corresponding to a HED tag which has a takes value child.
            type_value (str): The value of the type summarized by this class.
            number_elements (int): Number of elements in the data column

        """
        self.type_value = type_value
        self.number_elements = number_elements
        self.type_tag = type_tag.lower()
        self.levels = {}
        self.direct_indices = {}

    def __str__(self):
        return (f"[{self.type_value},{self.type_tag}]: {self.number_elements} elements "
                f"{str(self.levels)} levels {len(self.direct_indices)} references")

    def get_factors(self, factor_encoding="one-hot"):
        """ Return a DataFrame of factor vectors for this type factor.

        Parameters:
            factor_encoding (str): Specifies type of factor encoding (one-hot or categorical).

        Returns:
            DataFrame: DataFrame containing the factor vectors as the columns.

        Raises:
            ValueError: If factor_encoding is not one of ALLOWED_ENCODINGS.
            HedFileError: If categorical encoding is requested but some row has more than one level.

        Notes:
            When this factor has no levels, a single 0/1 column named by the type value is
            returned for either encoding.

        """
        # Validate first so that a bad argument fails the same way regardless of the data held.
        if factor_encoding not in self.ALLOWED_ENCODINGS:
            raise ValueError("BadFactorEncoding",
                             f"{factor_encoding} is not in the allowed encodings: {str(self.ALLOWED_ENCODINGS)}")

        if not self.levels:
            df = pd.DataFrame(0, index=range(self.number_elements), columns=[self.type_value])
            df.loc[list(self.direct_indices), [self.type_value]] = 1
            return df

        levels = list(self.levels)
        columns = [f"{self.type_value}.{level}" for level in levels]
        factors = pd.DataFrame(0, index=range(self.number_elements), columns=columns)
        for level, column in zip(levels, columns):
            factors.loc[list(self.levels[level]), [column]] = 1
        if factor_encoding == "one-hot":
            return factors

        # Categorical encoding can only represent one level per row.
        sum_factors = factors.sum(axis=1)
        if sum_factors.max() > 1:
            raise HedFileError("MultipleFactorSameEvent",
                               f"{self.type_value} has multiple occurrences at index {sum_factors.idxmax()}", "")
        return self._one_hot_to_categorical(factors, levels)

    def _one_hot_to_categorical(self, factors, levels):
        """ Collapse one-hot factor columns into a single categorical column.

        Parameters:
            factors (DataFrame): One-hot factors whose columns are named <type_value>.<level>.
            levels (list): Level names in the order used to build the factor columns.

        Returns:
            DataFrame: Single column named by the type value holding the lowercased level
                       name for each row, or 'n/a' for rows in which no level is set.

        """
        df = pd.DataFrame('n/a', index=range(len(factors.index)), columns=[self.type_value])
        available = set(factors.columns)
        has_direct = self.type_value in available
        # Pair each column name exactly as get_factors builds it with the lowercased label to
        # emit. Lowercasing the column name here would miss columns of mixed-case levels.
        level_columns = [(f"{self.type_value}.{level}", str(level).lower()) for level in levels]
        level_columns = [(column, label) for column, label in level_columns if column in available]
        for index, row in factors.iterrows():
            if has_direct and row[self.type_value]:
                df.at[index, self.type_value] = self.type_value
                continue
            for column, label in level_columns:
                if row[column]:
                    df.at[index, self.type_value] = label
                    break
        return df

    def get_summary(self):
        """ Return a dictionary summarizing the references to this type value.

        Returns:
            dict: Has keys 'type_value', 'type_tag', 'levels', 'direct_references', 'total_events',
                  'events', 'events_with_multiple_refs', 'max_refs_per_event', and 'level_counts'.

        """
        # Count, for every row, the direct references plus the references through each level.
        count_list = [0] * self.number_elements
        for index in self.direct_indices:
            count_list[index] += 1
        for cond in self.levels.values():
            for index in cond:
                count_list[index] += 1
        number_events, number_multiple, max_multiple = self._count_level_events(count_list)
        return {'type_value': self.type_value, 'type_tag': self.type_tag,
                'levels': len(self.levels), 'direct_references': len(self.direct_indices),
                'total_events': self.number_elements, 'events': number_events,
                'events_with_multiple_refs': number_multiple, 'max_refs_per_event': max_multiple,
                'level_counts': self._get_level_counts()}

    def _get_level_counts(self):
        """ Return a dictionary mapping each level name to its number of entries. """
        return {level: len(cond) for level, cond in self.levels.items()}

    @staticmethod
    def _count_level_events(count_list):
        """ Count the number of events and multiples in a list.

        Parameters:
            count_list (list): list of integers of the number of times a level occurs in an event.

        Returns:
            tuple:
                - int: Number of entries with a count greater than 0.
                - int: Number of entries with a count greater than 1.
                - int or None: Largest count in the list (None if the list is empty).

        """
        if not count_list:
            return 0, 0, None
        number_events = sum(1 for count in count_list if count > 0)
        number_multiple = sum(1 for count in count_list if count > 1)
        return number_events, number_multiple, max(count_list)