Skip to content

DataManager

BaseDataManager

Bases: ABC

Base data manager for loading and saving data.

Source code in utu/eval/data/data_manager.py (lines 13–34)
class BaseDataManager(abc.ABC):
    """Base data manager for loading and saving data.

    This is an abstract interface: subclasses must implement ``load``,
    ``save`` and ``get_samples``. The constructor only stores the config.
    """

    # Samples held by the manager. Declared here only; this class never assigns it.
    data: list[EvaluationSample]

    def __init__(self, config: EvalConfig) -> None:
        """Store the evaluation config on the instance.

        Args:
            config: Evaluation configuration, kept as ``self.config``.
        """
        self.config = config

    @abc.abstractmethod
    def load(self) -> list[EvaluationSample]:
        """Load the dataset.

        Returns:
            The loaded evaluation samples.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def save(self, **kwargs) -> None:
        """Save the dataset.

        Args:
            **kwargs: Implementation-specific arguments.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_samples(
        self, stage: Literal["init", "rollout", "judged"] | None = None
    ) -> list[EvaluationSample]:
        """Get samples of specified stage from the dataset.

        Args:
            stage: Stage to select, or ``None`` (the default) when no stage
                is specified.

        Returns:
            The matching evaluation samples.
        """
        raise NotImplementedError

load abstractmethod

load() -> list[EvaluationSample]

Load the dataset.

Source code in utu/eval/data/data_manager.py (lines 21–24)
@abc.abstractmethod
def load(self) -> list[EvaluationSample]:
    """Load the dataset.

    Abstract: this base implementation always raises ``NotImplementedError``.

    Returns:
        The loaded evaluation samples (in concrete implementations).
    """
    raise NotImplementedError()

save abstractmethod

save(**kwargs) -> None

Save the dataset.

Source code in utu/eval/data/data_manager.py (lines 26–29)
@abc.abstractmethod
def save(self, **kwargs) -> None:
    """Save the dataset.

    Abstract: this base implementation always raises ``NotImplementedError``.

    Args:
        **kwargs: Implementation-specific arguments.
    """
    raise NotImplementedError()

get_samples abstractmethod

get_samples(
    stage: Literal["init", "rollout", "judged"] = None,
) -> list[EvaluationSample]

Get samples of specified stage from the dataset.

Source code in utu/eval/data/data_manager.py (lines 31–34)
@abc.abstractmethod
def get_samples(
    self, stage: Literal["init", "rollout", "judged"] | None = None
) -> list[EvaluationSample]:
    """Get samples of specified stage from the dataset.

    Abstract: this base implementation always raises ``NotImplementedError``.

    Args:
        stage: Stage to select, or ``None`` (the default) when no stage is
            specified.

    Returns:
        The matching evaluation samples (in concrete implementations).
    """
    raise NotImplementedError

DBDataManager

Bases: BaseDataManager

Database data manager for loading and saving data.

Source code in utu/eval/data/data_manager.py (lines 37–117)
class DBDataManager(BaseDataManager):
    """Database data manager for loading and saving data.

    Every operation opens its own session via ``SQLModelUtils.create_session()``.
    Evaluation samples are scoped to ``config.exp_id``.
    """

    def __init__(self, config: EvalConfig) -> None:
        """Store the evaluation config.

        Args:
            config: Evaluation configuration; ``exp_id`` and ``data.dataset``
                are read from it by the methods below.
        """
        # Delegate to the base class instead of duplicating its assignment.
        super().__init__(config)

    def load(self) -> list[EvaluationSample]:
        """Load the evaluation samples for the configured experiment.

        If any sample with ``config.exp_id`` already exists, the stored
        samples are returned. Otherwise every ``DatasetSample`` of
        ``config.data.dataset`` is copied into a new ``EvaluationSample``
        tagged with the exp_id, and the new samples are saved.

        Returns:
            The experiment's evaluation samples, also kept on ``self.data``.
        """
        if self._check_exp_id():
            logger.warning(f"exp_id {self.config.exp_id} already exists in db")
            # Populate ``self.data`` on this path too, so the attribute is set
            # whichever branch ``load`` takes.
            self.data = self.get_samples()
            return self.data

        with SQLModelUtils.create_session() as session:
            datapoints = session.exec(
                select(DatasetSample).where(DatasetSample.dataset == self.config.data.dataset)
            ).all()
            logger.info(f"Loaded {len(datapoints)} samples from {self.config.data.dataset}.")
            # Rows are read while the session that produced them is still open.
            samples = [
                EvaluationSample(
                    dataset=dp.dataset,
                    dataset_index=dp.index,
                    source=dp.source,
                    raw_question=dp.question,
                    level=dp.level,
                    correct_answer=dp.answer,
                    file_name=dp.file_name,
                    meta=dp.meta,
                    exp_id=self.config.exp_id,  # add exp_id
                )
                for dp in datapoints
            ]

        # ``save`` opens its own session, so the read session is closed first.
        self.data = samples
        self.save(self.data)  # save to db
        return self.data

    def get_samples(
        self,
        stage: Literal["init", "rollout", "judged"] | None = None,
        limit: int | None = None,
    ) -> list[EvaluationSample]:
        """Get samples from exp_id with specified stage.

        Args:
            stage: If given (truthy), only samples whose ``stage`` equals it
                are returned; otherwise samples of every stage are returned.
            limit: Maximum number of samples; passed straight to ``.limit()``,
                where ``None`` (the default) applies no limit.

        Returns:
            Samples of ``config.exp_id`` ordered by ``dataset_index``.
        """
        # Add the stage filter only when requested, rather than passing a
        # bare Python ``True`` as a SQL criterion.
        criteria = [EvaluationSample.exp_id == self.config.exp_id]
        if stage:
            criteria.append(EvaluationSample.stage == stage)
        statement = (
            select(EvaluationSample)
            .where(*criteria)
            .order_by(EvaluationSample.dataset_index)
            .limit(limit)
        )
        with SQLModelUtils.create_session() as session:
            return session.exec(statement).all()

    def save(self, samples: list[EvaluationSample] | EvaluationSample) -> None:
        """Update or add sample(s) to db.

        Args:
            samples: A single sample or a list of samples to add to a new
                session, which is then committed.
        """
        with SQLModelUtils.create_session() as session:
            if isinstance(samples, list):
                session.add_all(samples)
            else:
                session.add(samples)
            session.commit()

    def delete_samples(self, samples: list[EvaluationSample] | EvaluationSample) -> None:
        """Delete sample(s) from db.

        Args:
            samples: A single sample or a list of samples to delete; all
                deletions are committed together.
        """
        batch = samples if isinstance(samples, list) else [samples]
        with SQLModelUtils.create_session() as session:
            for sample in batch:
                session.delete(sample)
            session.commit()

    def _check_exp_id(self) -> bool:
        """Return whether any evaluation sample already has ``config.exp_id``."""
        with SQLModelUtils.create_session() as session:
            existing = session.exec(
                select(EvaluationSample).where(EvaluationSample.exp_id == self.config.exp_id)
            ).first()
        return existing is not None

get_samples

get_samples(
    stage: Literal["init", "rollout", "judged"] = None,
    limit: int = None,
) -> list[EvaluationSample]

Get samples from exp_id with specified stage.

Source code in utu/eval/data/data_manager.py (lines 72–86)
def get_samples(
    self,
    stage: Literal["init", "rollout", "judged"] | None = None,
    limit: int | None = None,
) -> list[EvaluationSample]:
    """Get samples from exp_id with specified stage.

    Args:
        stage: If given (truthy), only samples whose ``stage`` equals it are
            returned; otherwise samples of every stage are returned.
        limit: Maximum number of samples; passed straight to ``.limit()``,
            where ``None`` (the default) applies no limit.

    Returns:
        Samples of ``config.exp_id`` ordered by ``dataset_index``.
    """
    # Add the stage filter only when requested, rather than passing a bare
    # Python ``True`` as a SQL criterion.
    criteria = [EvaluationSample.exp_id == self.config.exp_id]
    if stage:
        criteria.append(EvaluationSample.stage == stage)
    statement = (
        select(EvaluationSample)
        .where(*criteria)
        .order_by(EvaluationSample.dataset_index)
        .limit(limit)
    )
    with SQLModelUtils.create_session() as session:
        return session.exec(statement).all()

save

save(
    samples: list[EvaluationSample] | EvaluationSample,
) -> None

Update or add sample(s) to db.

Source code in utu/eval/data/data_manager.py (lines 88–97)
def save(self, samples: list[EvaluationSample] | EvaluationSample) -> None:
    """Update or add sample(s) to db.

    Args:
        samples: A single sample or a list of samples to add to a new
            session, which is then committed.
    """
    # One session and one commit serve both the list and single-sample case.
    with SQLModelUtils.create_session() as db:
        if not isinstance(samples, list):
            db.add(samples)
        else:
            db.add_all(samples)
        db.commit()

delete_samples

delete_samples(
    samples: list[EvaluationSample] | EvaluationSample,
) -> None

Delete sample(s) from db.

Source code in utu/eval/data/data_manager.py (lines 99–109)
def delete_samples(self, samples: list[EvaluationSample] | EvaluationSample) -> None:
    """Delete sample(s) from db.

    Args:
        samples: A single sample or a list of samples to delete; all
            deletions are committed together.
    """
    # Treat a lone sample as a one-element batch so both cases share a path.
    doomed = samples if isinstance(samples, list) else [samples]
    with SQLModelUtils.create_session() as db:
        for item in doomed:
            db.delete(item)
        db.commit()