Skip to content

aigve.core

VQALoop

Bases: BaseLoop

Loop for VQA metric evaluation.

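Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `runner` | `Runner` | A reference of runner. | *required* |
| `dataloader` | `Dataloader` or `dict` | A dataloader object or a dict to build a dataloader. | *required* |
| `evaluator` | `Evaluator` or `dict` or `list` | Used for computing metrics. A `dict` or `list` is built via `runner.build_evaluator`. | *required* |
| `fp16` | `bool` | Whether to enable fp16 validation. Note: the flag is stored but not currently applied, because the autocast/model step in `run_iter` is disabled. | `False` |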
Source code in aigve/core/loops.py
@LOOPS.register_module()
class VQALoop(BaseLoop):
    """Loop for VQA metric evaluation.

    Unlike a conventional validation loop, this loop runs no model forward
    pass: every batch yielded by the dataloader is handed to the evaluator
    directly, as both ``data_batch`` and ``data_samples``.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
            A dict or list is built into an evaluator via
            ``runner.build_evaluator``.
        fp16 (bool): Whether to enable fp16 validation. Defaults to
            False. NOTE: the flag is stored on ``self.fp16`` but is not
            currently applied, since no model forward pass is executed.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 fp16: bool = False) -> None:
        super().__init__(runner, dataloader)
        if isinstance(evaluator, (dict, list)):
            self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
        else:
            assert isinstance(evaluator, Evaluator), (
                'evaluator must be one of dict, list or Evaluator instance, '
                f'but got {type(evaluator)}.')
            self.evaluator = evaluator  # type: ignore
        # Share the dataset's meta information with the evaluator and the
        # visualizer when the dataset exposes it; otherwise only warn.
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in evaluator, metric and '
                'visualizer will be None.',
                logger='current',
                level=logging.WARNING)
        # Stored for configuration purposes; not read by run()/run_iter().
        self.fp16 = fp16
        # Loss buffers; not updated by run()/run_iter() because loss
        # tracking is currently disabled in this loop.
        self.val_loss: Dict[str, HistoryBuffer] = dict()

    def run(self) -> dict:
        """Launch validation.

        Fires the ``before_val`` and ``before_val_epoch`` hooks, puts the
        model in eval mode, feeds every dataloader batch to :meth:`run_iter`,
        then computes metrics and fires ``after_val_epoch``/``after_val``.

        Returns:
            dict: The metrics returned by ``self.evaluator.evaluate``.
        """
        self.runner.call_hook('before_val')
        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()

        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        # Compute metrics once every batch has been processed.
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

        self.runner.call_hook('after_val_epoch', metrics=metrics)
        self.runner.call_hook('after_val')
        return metrics

    def run_iter(self, idx: int, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        Args:
            idx (int): Index of the mini-batch within the epoch; forwarded
                to the hooks as ``batch_idx``.
            data_batch (Sequence[dict]): Batch of data
                from dataloader.
        """
        self.runner.call_hook(
            'before_val_iter', batch_idx=idx, data_batch=data_batch)
        # No model forward pass is run here: the batch itself is treated as
        # the outputs, so the evaluator consumes the raw data samples.
        outputs = data_batch

        self.evaluator.process(data_batch=data_batch, data_samples=outputs)
        self.runner.call_hook(
            'after_val_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)

run()

Launch validation.

Source code in aigve/core/loops.py
def run(self) -> dict:
    """Launch validation.

    Emits the opening validation hooks, switches the model to eval mode,
    hands each dataloader batch to ``self.run_iter`` and finally evaluates
    over the size of the whole dataset.

    Returns:
        dict: Whatever ``self.evaluator.evaluate`` produced.
    """
    for opening_hook in ('before_val', 'before_val_epoch'):
        self.runner.call_hook(opening_hook)
    self.runner.model.eval()

    for batch_position, batch in enumerate(self.dataloader):
        self.run_iter(batch_position, batch)

    # Evaluation happens only after the full pass over the dataloader.
    dataset_size = len(self.dataloader.dataset)
    results = self.evaluator.evaluate(dataset_size)

    self.runner.call_hook('after_val_epoch', metrics=results)
    self.runner.call_hook('after_val')
    return results

run_iter(idx, data_batch)

Iterate one mini-batch.

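Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `idx` | `int` | Index of the mini-batch; forwarded to the hooks as `batch_idx`. | *required* |
| `data_batch` | `Sequence[dict]` | Batch of data from dataloader. | *required* |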
Source code in aigve/core/loops.py
def run_iter(self, idx: int, data_batch: Sequence[dict]) -> None:
    """Iterate one mini-batch.

    Args:
        idx (int): Index of the mini-batch within the epoch; forwarded to
            the hooks as ``batch_idx``.
        data_batch (Sequence[dict]): Batch of data
            from dataloader.
    """
    self.runner.call_hook(
        'before_val_iter', batch_idx=idx, data_batch=data_batch)
    # No model forward pass is run here: the batch itself is treated as the
    # outputs, so the evaluator consumes the raw data samples.
    outputs = data_batch

    self.evaluator.process(data_batch=data_batch, data_samples=outputs)
    self.runner.call_hook(
        'after_val_iter',
        batch_idx=idx,
        data_batch=data_batch,
        outputs=outputs)