"""
Getting data into the required format for use with this library.
There are 3 useful ways to represent batch data.
``dict``: as a Python dictionary. Example::
data = {
"batch 1": data frame with a varying number of rows, but the same columns (same column names) as every other batch,
"batch 2": etc,
}
The keys are unique identifiers for each batch, such as integers or strings.
``melt``: as a single Pandas data frame::
data = pd.DataFrame(...)
Characteristics:
- very large number of rows, for all batches stacked vertically on top of each other
- some number of columns, one column per tag
- one column, usually called ``batch_id``, indicates what the batch number is for that row
- another column, usually called ``time``, indicates what the time is within that batch
- typically sorted, but does not have to be
``wide``: as a single Pandas data frame, like the "melted" version, but pivoted so that each batch occupies a single row.
These ``wide`` dataframes *always* have a multilevel column index to distinguish the tags
from the time. This representation is only valid for aligned data. Example::
data = pd.DataFrame(...)
Characteristics:
- each row is a unique batch number
- the multilevel column index has level 0 = column name, level 1 = aligned time
- only makes sense if the data are aligned, i.e. every batch has the same number of time steps, so every tag has the same set of level-1 (time) labels
"""
from __future__ import annotations
import numpy as np
import pandas as pd
[docs]
def check_valid_batch_dict(in_dict: dict, no_nan: bool = False) -> bool:
    """Check if the incoming dictionary of batch data is a valid dictionary of data.

    Checks, applied to every batch in dictionary order:

    1. The batch has exactly the same set of column names as the first batch
       (column order is not compared).
    2. All columns are numeric, i.e. selected by ``select_dtypes(include=[np.number])``.
    3. If `no_nan` is True, the batch contains no missing values.

    Parameters
    ----------
    in_dict : dict
        A dictionary of batch data: keys identify the batches, values are data frames.
    no_nan : bool
        If True, will also check that no missing values are present.

    Returns
    -------
    bool
        True, if it passes the checks.

    Raises
    ------
    AssertionError
        If the dictionary is empty, or if any batch fails one of the checks. The message
        names the offending batch.
    """
    assert len(in_dict) >= 1, "At least 1 batch is required in the dataframe dictionary."
    base_columns = set(next(iter(in_dict.values())).columns)

    # `check` is accumulated (rather than relying only on the asserts) so that, if
    # assertions are disabled with `python -O`, the function returns False on failure
    # instead of unconditionally returning True.
    check = True
    for bid, batch in in_dict.items():
        # Check 1: identical column names (order-insensitive).
        batch_columns = set(batch.columns)
        check = check and base_columns == batch_columns
        assert check, (
            f"The column names must be the same in all batches. Differs in {bid}. Base "
            f"columns = {base_columns}; this batch has: {batch_columns}"
        )
        # Check 2: every column has a numeric dtype.
        check = check and batch.select_dtypes(include=[np.number]).shape[1] == batch.shape[1]
        assert check, f"All columns must be a numeric type. Differs in {bid}."
        # Check 3: optionally, no missing values anywhere in the batch.
        if no_nan:
            check = check and not batch.isna().to_numpy().any()
            assert check, f"No missing values allowed. Missing values found in {bid}."
    return bool(check)
[docs]
def dict_to_melted(
    in_df: dict,
    insert_batch_id_column: bool = True,
    insert_sequence_column: bool = False,
) -> pd.DataFrame:
    """Stack a dictionary of equally-sized batches into one "melted" data frame.

    Reverse of `melted_to_dict`.

    Parameters
    ----------
    in_df : dict
        Dictionary mapping batch identifiers to data frames. Despite its name this
        parameter is a dict, not a data frame (the name is kept for backward
        compatibility). Every batch must have the same number of rows.
    insert_batch_id_column : bool
        If True, insert a leading ``batch_id`` column holding the dictionary key,
        unless the batch already has a ``batch_id`` column.
    insert_sequence_column : bool
        If True, insert a leading ``_sequence_`` column numbering the rows of each
        batch 0, 1, ..., n-1. When both columns are inserted, ``_sequence_`` comes
        first, followed by ``batch_id``.

    Returns
    -------
    pd.DataFrame
        The batches concatenated vertically, in dictionary order. Each batch keeps its
        own row index (it is not reset). The input data frames are not modified.

    Raises
    ------
    AssertionError
        If the batches do not all have the same number of rows.
    ValueError
        If `in_df` is empty (raised by ``pd.concat`` when given nothing to concatenate).
    """
    batch_id_col = "batch_id"
    all_batches = []
    num_rows = None
    sequence = None
    for batch_id, batch in in_df.items():
        if num_rows is None:
            # The first batch fixes the expected length and the shared row numbering.
            num_rows = batch.shape[0]
            sequence = np.arange(0, num_rows)
        assert num_rows == batch.shape[0], "All batches must have the same number of samples"
        # Work on a copy so the caller's data frames are left untouched.
        subset = batch.copy()
        if insert_batch_id_column and batch_id_col not in batch:
            subset.insert(0, batch_id_col, batch_id)
        if insert_sequence_column:
            subset.insert(0, "_sequence_", sequence)
        all_batches.append(subset)
    return pd.concat(all_batches)
[docs]
def dict_to_wide(in_df: dict, group_by_batch: bool = False) -> pd.DataFrame:
    """
    Convert aligned batch data from dict to wide format.

    The batches are first stacked with `dict_to_melted` (adding ``batch_id`` and
    ``_sequence_`` columns) and then pivoted with ``pivot_table``, using ``batch_id``
    as the row index and ``_sequence_`` as the column level. The result therefore has
    one row per batch and a multilevel column index of (column name, sequence number).
    ``pivot_table``'s defaults apply. It aggregates with the mean, which only reshapes
    when each (batch_id, _sequence_) pair occurs once, and it drops columns that are
    entirely missing.

    `group_by_batch`, if True, is meant to put all the data from the first batch on the
    left of the output dataframe and the last batch on the right. If `group_by_batch` is
    False, then data for the same tag are grouped together, side-by-side.

    TODO: `group_by_batch` is not implemented yet. Passing True emits a warning and
    returns the same result as False.
    """
    out_df = dict_to_melted(in_df=in_df, insert_batch_id_column=True, insert_sequence_column=True)
    aligned_wide_df = out_df.pivot_table(index="batch_id", columns="_sequence_")
    if group_by_batch:
        # TODO: use the hierarchical indexing and regroup the columns.
        # Until then, tell the caller instead of silently ignoring the request.
        import warnings

        warnings.warn(
            "dict_to_wide: `group_by_batch=True` is not implemented yet and is ignored.",
            UserWarning,
            stacklevel=2,
        )
    return aligned_wide_df
[docs]
def melted_to_dict(in_df: pd.DataFrame, batch_id_col: str) -> dict:
    """
    Split a "melted" data set into one data frame per batch.

    Rows are grouped on the unique values of the `batch_id_col` column. The returned
    dictionary maps each batch identifier to the sub-frame holding that batch's rows.
    Every sub-frame keeps all of the original columns, including `batch_id_col` itself.
    With pandas' default ``groupby`` settings the keys come out in sorted order, and
    rows whose batch identifier is missing are left out.
    """
    assert batch_id_col in in_df, "The `batch_id_col` column does not exist in the incoming dataframe."
    batches: dict = {}
    for key, group in in_df.groupby(batch_id_col):
        batches[key] = group
    return batches
[docs]
def melted_to_wide(in_df: pd.DataFrame, batch_id_col: str) -> dict:
    """Convert aligned melted data to wide format.

    Not implemented yet. The function only asserts that `batch_id_col` is a column of
    `in_df` and then returns an empty dictionary.
    """
    assert batch_id_col in in_df
    # TODO: implement. An earlier draft pivoted the aligned frame with
    # `pivot(index="batch_id", columns="_sequence_")` and flattened the column labels
    # to "<tag>-<zero-padded sequence>". The intended design keeps a multilevel column
    # index instead, e.g. `dict_to_wide(melted_to_dict(in_df, batch_id_col))`.
    placeholder: dict = dict()
    return placeholder
[docs]
def wide_to_melted(in_df: pd.DataFrame) -> pd.DataFrame:
    """Placeholder for converting wide-format batch data to melted format.

    Not implemented yet: `in_df` is ignored and a new, empty data frame is returned.
    """
    # TODO: one possible route is `dict_to_melted(wide_to_dict(in_df))`, once
    # `wide_to_dict` accepts a wide data frame.
    empty_result = pd.DataFrame()
    return empty_result
[docs]
def wide_to_dict() -> None:
    """Placeholder for converting wide-format batch data to dict format.

    Not implemented yet: the function takes no arguments and returns None.
    """
    # TODO: accept a wide data frame and split it into one data frame per batch row.
    return None
[docs]
def melt_df_to_series(in_df: pd.DataFrame, exclude_columns: list | None = None, name: str | None = None) -> pd.Series:
"""Return a Series with a multilevel-index, melted from the DataFrame."""
if exclude_columns is None:
exclude_columns = ["batch_id"]
out = in_df.drop(exclude_columns, axis=1).T.stack() # noqa: PD013 # noqa: PD013
out.name = name
return out