aigve

The aigve library provides a comprehensive and structured evaluation framework for assessing the quality of AI-generated videos. It integrates multiple evaluation metrics covering diverse aspects of video evaluation, including neural-network-based assessment, distribution comparison, vision-language alignment, and multi-faceted analysis.

CLIPSimScore

Bases: BaseMetric

Initialize the CLIPSimScore evaluator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| processor_name | str | The name of the CLIP processor, which wraps a CLIP feature extractor and a CLIP tokenizer into a single processor. | 'openai/clip-vit-base-patch32' |
| model_name | str | The name of the CLIP model. | 'openai/clip-vit-base-patch32' |
| logit_scale | bool | Whether to calculate the cosine similarity as logits. | False |
Source code in aigve/metrics/text_video_alignment/similarity_based/clipscore/clipsim.py
@METRICS.register_module()
class CLIPSimScore(BaseMetric):
    """ Initialize the ``CLIPSimScore`` evaluator.

    Args:
        processor_name (str): The name of the CLIP processor, which wraps a CLIP feature extractor and a CLIP tokenizer into a single processor.
                                Defaults to ``openai/clip-vit-base-patch32``.
        model_name (str): The name of the CLIP model. Defaults to ``openai/clip-vit-base-patch32``.
        logit_scale (bool): Whether to calculate the cosine similarity as logits. Defaults to False.
    """
    def __init__(self,
                 processor_name: str = "openai/clip-vit-base-patch32",
                 model_name: str = "openai/clip-vit-base-patch32",
                 logit_scale: bool = False,
                #  train_index: int = 4
                 ) -> None:
        super().__init__()
        self.processor_name = processor_name
        self.model_name = model_name
        self.logit_scale = logit_scale

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = AutoProcessor.from_pretrained(self.processor_name)
        self.model = CLIPModel.from_pretrained(self.model_name).to(self.device)
        self.model.eval()

    def process(self, data_batch: Sequence, data_samples: Sequence) -> None:
        """CLIPSimScore process
        Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence): A batch of data from the dataloader.
            data_samples (Sequence): A batch of data samples that
                contain annotations and predictions.
        """

        result = dict()

        input_prompts, input_videos = data_samples
        bsz = len(input_prompts)

        # Ensure the prompt and video inputs are mutable lists (the dataloader may yield tuples)
        if isinstance(input_prompts, tuple):
            input_prompts = list(input_prompts)

        if isinstance(input_videos, tuple):
            input_videos = list(input_videos)

        # Accumulate the per-video similarity scores for this batch
        clip_score_sum, clip_score_cnt = 0, 0
        logit_scale = self.model.logit_scale.exp() if self.logit_scale else 1
        with torch.no_grad():
            for input_prompt, input_frames in zip(input_prompts, input_videos):
                input_prompt = input_prompt.to(self.device)
                text_feature = self.model.get_text_features(input_prompt) # [bsz, hid_dim]
                text_feature = text_feature / torch.norm(text_feature, dim=-1, keepdim=True)

                input_frames = input_frames.to(self.device)  # Move the frame tensor to the device
                frame_feature = self.model.get_image_features(input_frames)
                frame_feature = frame_feature / torch.norm(frame_feature, dim=-1, keepdim=True)

                clip_score = logit_scale * (frame_feature @ text_feature.T).mean().item()
                print('current clip similarity score', clip_score)
                clip_score_sum += clip_score
                clip_score_cnt += 1

        # Calculate the average CLIP score across all frames
        clip_score_videos_avg = clip_score_sum/clip_score_cnt

        result['clip_sim_score'] = clip_score_videos_avg

        self.results.append(result)


    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
            the metrics, and the values are corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        clip_score_np = np.zeros(len(results))
        for i, result in enumerate(results):
            clip_score_np[i] = result['clip_sim_score']

        clip_sim_mean = np.mean(clip_score_np) 

        print("Test results: clip similarity score={:.4f}"
              .format(clip_sim_mean))

        return {'clip_sim_score': clip_sim_mean}
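
Below is a minimal usage sketch, not part of the library source, showing how the evaluator above could be driven directly. It assumes CLIPSimScore is importable as documented here, and that prompts and videos are already preprocessed into CLIP token ids and pixel values, since process passes them straight to get_text_features and get_image_features.

# Hedged usage sketch: inputs are preprocessed with the Hugging Face CLIP processor.
import numpy as np
from transformers import AutoProcessor

metric = CLIPSimScore()  # uses openai/clip-vit-base-patch32 by default

processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
# One prompt and one dummy "video" of 4 blank RGB frames (224x224).
prompt = processor(text=["a red car driving on a highway"], padding=True,
                   truncation=True, max_length=77, return_tensors="pt")["input_ids"]
frames = processor(images=[np.zeros((224, 224, 3), dtype=np.uint8)] * 4,
                   return_tensors="pt")["pixel_values"]  # [4, 3, 224, 224]

# data_samples is (prompts, videos); data_batch is unused by this metric.
metric.process(data_batch=None, data_samples=([prompt], [frames]))
print(metric.compute_metrics(metric.results))  # e.g. {'clip_sim_score': ...}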

compute_metrics(results)

Compute the metrics from processed results.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| results | list | The processed results of each batch. | required |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, float] | The computed metrics. The keys are the names of the metrics, and the values are corresponding results. |

Source code in aigve/metrics/text_video_alignment/similarity_based/clipscore/clipsim.py
def compute_metrics(self, results: list) -> Dict[str, float]:
    """Compute the metrics from processed results.

    Args:
        results (list): The processed results of each batch.

    Returns:
        Dict[str, float]: The computed metrics. The keys are the names of
        the metrics, and the values are corresponding results.
    """
    logger: MMLogger = MMLogger.get_current_instance()

    clip_score_np = np.zeros(len(results))
    for i, result in enumerate(results):
        clip_score_np[i] = result['clip_sim_score']

    clip_sim_mean = np.mean(clip_score_np) 

    print("Test results: clip similarity score={:.4f}"
          .format(clip_sim_mean))

    return {'clip_sim_score': clip_sim_mean}

process(data_batch, data_samples)

Process one batch of data samples and predictions for CLIPSimScore. The processed results are stored in self.results, which is used to compute the metrics once all batches have been processed.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_batch | Sequence | A batch of data from the dataloader. | required |
| data_samples | Sequence | A batch of data samples that contain annotations and predictions. | required |
Source code in aigve/metrics/text_video_alignment/similarity_based/clipscore/clipsim.py
def process(self, data_batch: Sequence, data_samples: Sequence) -> None:
    """CLIPSimScore process
    Process one batch of data samples and predictions. The processed
    results should be stored in ``self.results``, which will be used to
    compute the metrics when all batches have been processed.

    Args:
        data_batch (Sequence): A batch of data from the dataloader.
        data_samples (Sequence): A batch of data samples that
            contain annotations and predictions.
    """

    result = dict()

    input_prompts, input_videos = data_samples
    bsz = len(input_prompts)

    # Ensure the prompt and video inputs are mutable lists (the dataloader may yield tuples)
    if isinstance(input_prompts, tuple):
        input_prompts = list(input_prompts)

    if isinstance(input_videos, tuple):
        input_videos = list(input_videos)

    # Accumulate the per-video similarity scores for this batch
    clip_score_sum, clip_score_cnt = 0, 0
    logit_scale = self.model.logit_scale.exp() if self.logit_scale else 1
    with torch.no_grad():
        for input_prompt, input_frames in zip(input_prompts, input_videos):
            input_prompt = input_prompt.to(self.device)
            text_feature = self.model.get_text_features(input_prompt) # [bsz, hid_dim]
            text_feature = text_feature / torch.norm(text_feature, dim=-1, keepdim=True)

            input_frames = input_frames.to(self.device)  # Move the frame tensor to the device
            frame_feature = self.model.get_image_features(input_frames)
            frame_feature = frame_feature / torch.norm(frame_feature, dim=-1, keepdim=True)

            clip_score = logit_scale * (frame_feature @ text_feature.T).mean().item()
            print('current clip similarity score', clip_score)
            clip_score_sum += clip_score
            clip_score_cnt += 1

    # Calculate the average CLIP score across all frames
    clip_score_videos_avg = clip_score_sum/clip_score_cnt

    result['clip_sim_score'] = clip_score_videos_avg

    self.results.append(result)

CLIPTempDataset

Bases: Dataset

Source code in aigve/datasets/cliptemp_dataset.py
@DATASETS.register_module()
class CLIPTempDataset(Dataset):
    def __init__(self, processor_name, prompt_dir, video_dir):
        super(CLIPTempDataset, self).__init__()
        self.prompt_dir = prompt_dir
        self.video_dir = video_dir
        self.processor_name = processor_name

        self.processor = AutoProcessor.from_pretrained(self.processor_name)
        self.video_names = self._read_videoname()

    def _read_videoname(self):
        with open(self.prompt_dir, 'r') as reader:
            read_data = json.load(reader)

        video_name_list = []
        for item in read_data["datset_list"]:
            video_name = item['video_path_pd'].strip()
            video_name_list.append(video_name)

        return video_name_list

    def __len__(self):
        return len(self.video_names)-1

    def __getitem__(self, index):
        '''Return the preprocessed frame tensor for one video.
        '''
        video_name = self.video_names[index]
        video_path = self.video_dir + video_name
        frames = []
        cap = cv2.VideoCapture(video_path)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            resized_frame = cv2.resize(frame,(224,224))  # Resize the frame to match the expected input size
            frames.append(resized_frame)

        input_frame_tensor = self.processor(
            images=frames,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )['pixel_values']

        return input_frame_tensor
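
The dataset above can be wrapped in a standard PyTorch DataLoader. The sketch below is illustrative only; the annotation file and video directory are hypothetical placeholders, and the prompt JSON is expected to list generated-video paths in the format read by _read_videoname.

# Illustrative only: the paths below are hypothetical placeholders.
from torch.utils.data import DataLoader

dataset = CLIPTempDataset(
    processor_name="openai/clip-vit-base-patch32",
    prompt_dir="prompts.json",  # hypothetical JSON with generated-video entries
    video_dir="videos/",        # note: the video name is appended directly to this string
)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

for frame_tensor in loader:
    # CLIP pixel values for one video, shape [1, T, 3, 224, 224] after batching
    print(frame_tensor.shape)
    break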

__getitem__(index)

Return the preprocessed frame tensor for one video.

Source code in aigve/datasets/cliptemp_dataset.py
def __getitem__(self, index):
    '''Return the preprocessed frame tensor for one video.
    '''
    video_name = self.video_names[index]
    video_path = self.video_dir + video_name
    frames = []
    cap = cv2.VideoCapture(video_path)
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        resized_frame = cv2.resize(frame,(224,224))  # Resize the frame to match the expected input size
        frames.append(resized_frame)

    input_frame_tensor = self.processor(
        images=frames,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    )['pixel_values']

    return input_frame_tensor

CLIPTempScore

Bases: BaseMetric

Initialize the CLIPTempScore evaluator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_name | str | The name of the CLIP encoder model. | 'openai/clip-vit-base-patch32' |
| logit_scale | bool | Whether to calculate the cosine similarity as logits. | False |
Source code in aigve/metrics/text_video_alignment/similarity_based/clipscore/cliptemp.py
@METRICS.register_module()
class CLIPTempScore(BaseMetric):
    """ Initialize the ``CLIPTempScore`` evaluator.

    Args:
        model_name (str): The name of the CLIP encoder model. Defaults to ``openai/clip-vit-base-patch32``.
        logit_scale (bool): Whether to calculate the cosine similarity as logits. Defaults to False.

    """
    def __init__(self,
                 model_name: str = "openai/clip-vit-base-patch32",
                 logit_scale: bool = False,
                #  train_index: int = 4
                 ) -> None:
        super().__init__()
        self.model_name = model_name
        self.logit_scale = logit_scale

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = CLIPModel.from_pretrained(self.model_name).to(self.device)
        self.model.eval()

    def process(self, data_batch: Sequence, data_samples: Sequence) -> None:
        """CLIPTempScore process
        Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence): A batch of data from the dataloader.
            data_samples (Sequence): A batch of data samples that
                contain annotations and predictions.
        """

        result = dict()

        input_videos = data_samples
        # bsz = len(input_videos)


        # Ensure the video inputs are a mutable list (the dataloader may yield a tuple)
        if isinstance(input_videos, tuple):
            input_videos = list(input_videos)

        # Generate embeddings for each frame and concatenate the features
        clip_temp_score_sum, clip_temp_score_cnt = 0, 0
        logit_scale = self.model.logit_scale.exp() if self.logit_scale else 1
        with torch.no_grad():  
            for input_frames in input_videos: # A video may contain many frames; embed them per video to limit memory usage
                input_frames = input_frames.to(self.device)
                frame_feature = self.model.get_image_features(input_frames)
                frame_feature = frame_feature / torch.norm(frame_feature, dim=-1, keepdim=True)
                # print(frame_feature.shape)

                clip_temp_score_list = []
                for i in range(frame_feature.shape[0]-1):
                    clip_temp_score = logit_scale * frame_feature[i].unsqueeze(0) @ frame_feature[i+1].unsqueeze(0).T
                    clip_temp_score = clip_temp_score.item()
                    # print(clip_temp_score)
                    clip_temp_score_list.append(clip_temp_score)
                clip_temp_cur_avg_score = sum(clip_temp_score_list)/len(clip_temp_score_list)
                clip_temp_score_sum += clip_temp_cur_avg_score
                clip_temp_score_cnt += 1
                print('current clip temp similarity score', clip_temp_cur_avg_score)

        clip_temp_score_avg = clip_temp_score_sum/clip_temp_score_cnt

        result['clip_temp_score'] = clip_temp_score_avg

        self.results.append(result)


    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
            the metrics, and the values are corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        clip_score_np = np.zeros(len(results))
        for i, result in enumerate(results):
            clip_score_np[i] = result['clip_temp_score']

        clip_temp_mean = np.mean(clip_score_np) 

        print("Test results: clip temporal consistency score={:.4f}"
              .format(clip_temp_mean))

        return {'clip_temp_score': clip_temp_mean}
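
As a quick sanity check, the sketch below (illustrative, not library code) feeds a single dummy clip through process; the metric averages the cosine similarity between each pair of consecutive frame embeddings, so higher values indicate smoother temporal consistency.

# Hedged example with random pixel values standing in for preprocessed frames.
import torch

metric = CLIPTempScore()                    # openai/clip-vit-base-patch32 by default
dummy_video = torch.randn(8, 3, 224, 224)   # [T, C, H, W] pixel values expected by CLIP

# data_samples is a sequence of per-video frame tensors; data_batch is unused.
metric.process(data_batch=None, data_samples=[dummy_video])
print(metric.compute_metrics(metric.results))  # e.g. {'clip_temp_score': ...}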

compute_metrics(results)

Compute the metrics from processed results.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| results | list | The processed results of each batch. | required |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, float] | The computed metrics. The keys are the names of the metrics, and the values are corresponding results. |

Source code in aigve/metrics/text_video_alignment/similarity_based/clipscore/cliptemp.py
def compute_metrics(self, results: list) -> Dict[str, float]:
    """Compute the metrics from processed results.

    Args:
        results (list): The processed results of each batch.

    Returns:
        Dict[str, float]: The computed metrics. The keys are the names of
        the metrics, and the values are corresponding results.
    """
    logger: MMLogger = MMLogger.get_current_instance()

    clip_score_np = np.zeros(len(results))
    for i, result in enumerate(results):
        clip_score_np[i] = result['clip_temp_score']

    clip_temp_mean = np.mean(clip_score_np) 

    print("Test results: clip temporal consistency score={:.4f}"
          .format(clip_temp_mean))

    return {'clip_temp_score': clip_temp_mean}

process(data_batch, data_samples)

Process one batch of data samples and predictions for CLIPTempScore. The processed results are stored in self.results, which is used to compute the metrics once all batches have been processed.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_batch | Sequence | A batch of data from the dataloader. | required |
| data_samples | Sequence | A batch of data samples that contain annotations and predictions. | required |
Source code in aigve/metrics/text_video_alignment/similarity_based/clipscore/cliptemp.py
def process(self, data_batch: Sequence, data_samples: Sequence) -> None:
    """CLIPTempScore process
    Process one batch of data samples and predictions. The processed
    results should be stored in ``self.results``, which will be used to
    compute the metrics when all batches have been processed.

    Args:
        data_batch (Sequence): A batch of data from the dataloader.
        data_samples (Sequence): A batch of data samples that
            contain annotations and predictions.
    """

    result = dict()

    input_videos = data_samples
    # bsz = len(input_videos)


    # Ensure the video inputs are a mutable list (the dataloader may yield a tuple)
    if isinstance(input_videos, tuple):
        input_videos = list(input_videos)

    # Generate embeddings for each frame and concatenate the features
    clip_temp_score_sum, clip_temp_score_cnt = 0, 0
    logit_scale = self.model.logit_scale.exp() if self.logit_scale else 1
    with torch.no_grad():  
        for input_frames in input_videos: # A video may contain many frames; embed them per video to limit memory usage
            input_frames = input_frames.to(self.device)
            frame_feature = self.model.get_image_features(input_frames)
            frame_feature = frame_feature / torch.norm(frame_feature, dim=-1, keepdim=True)
            # print(frame_feature.shape)

            clip_temp_score_list = []
            for i in range(frame_feature.shape[0]-1):
                clip_temp_score = logit_scale * frame_feature[i].unsqueeze(0) @ frame_feature[i+1].unsqueeze(0).T
                clip_temp_score = clip_temp_score.item()
                # print(clip_temp_score)
                clip_temp_score_list.append(clip_temp_score)
            clip_temp_cur_avg_score = sum(clip_temp_score_list)/len(clip_temp_score_list)
            clip_temp_score_sum += clip_temp_cur_avg_score
            clip_temp_score_cnt += 1
            print('current clip temp similarity score', clip_temp_cur_avg_score)

    clip_temp_score_avg = clip_temp_score_sum/clip_temp_score_cnt

    result['clip_temp_score'] = clip_temp_score_avg

    self.results.append(result)

DSGScore

Bases: BaseMetric

Initialize the DSGScore evaluator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| vqa_model_name | str | The name of the VQA model used in the DSGScore evaluator. Either "InstructBLIP" (default) or "MPLUG". | 'InstructBLIP' |
| verbose | bool | Whether to print intermediate outputs. | False |
Source code in aigve/metrics/text_video_alignment/gpt_based/dsg/dsg_eval.py
@METRICS.register_module()
class DSGScore(BaseMetric):
    """ Initialize the ``DSGScore`` evaluator.

    Args:
        vqa_model_name (str): The name of the VQA model used in the DSGScore evaluator. Defaults to ``InstructBLIP``; ``MPLUG`` can also be chosen as the VQA model.
        verbose (bool): Whether to print intermediate outputs. Defaults to False.
    """
    def __init__(self, 
                 vqa_model_name: str = "InstructBLIP",
                 verbose: bool = False):
        super().__init__()

        self.submodel_path = 'metrics/text_video_alignment/gpt_based/dsg'
        if not submodule_exists(self.submodel_path):
            add_git_submodule(
                repo_url='https://github.com/j-min/DSG.git', 
                submodule_path=self.submodel_path
            )     
        from .DSG.dsg.vqa_utils import MPLUG, InstructBLIP

        self.vqa_model_name = vqa_model_name
        assert self.vqa_model_name in ["InstructBLIP", "MPLUG"]
        if self.vqa_model_name == 'InstructBLIP':
            self.vqa_model = InstructBLIP()
        else:
            self.vqa_model = MPLUG()

        self.verbose = verbose
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    def evaluate_image_dsg(self, qid_list, frame_index, frame) -> Dict[str, Union[int, dict, float]]:
        """ Evaluate a generated image with DSG evaluator; this is the intermediate process of the ``process`` function. 

        Args:
            qid_list (List[str]): The list of DSG parse question generation results.
            frame_index (int): The index number of the currently evaluated frame.
            frame (List[List[float]]): The current evaluated frame.

        Returns:
            Dict[str, Union[int, dict, float]]: A dictionary containing evaluation results with the following keys:
                - 'frame_index' (int): The index of the evaluated frame.
                - 'qid2tuple' (dict): Mapping of question IDs to tuples.
                - 'qid2dependency' (dict): Mapping of question IDs to dependencies.
                - 'qid2question' (dict): Mapping of question IDs to actual questions.
                - 'qid2answer' (dict): Mapping of question IDs to predicted answers.
                - 'qid2scores' (dict): Mapping of question IDs to scores before dependency filtering.
                - 'qid2validity' (dict): Mapping of question IDs to boolean validity after dependency filtering.
                - 'average_score_with_dependency' (float): Average score considering dependency filtering.
                - 'average_score_without_dependency' (float): Average score before dependency filtering.
        """
        if self.verbose:
            print("#"*50)
            print("2) Answer questions given the generated image, with VQA")
            print("#"*50)

        # 2) answer questions with the generated image
        qid2answer = {}
        qid2scores = {}

        qid2tuple, qid2dependency, qid2question = qid_list
        for id, question in qid2question.items():
            answer = self.vqa_model.vqa(image=frame, question=question)
            print(answer)
            qid2answer[id] = answer
            qid2scores[id] = float('yes' in answer)

        average_score_without_dep = sum(qid2scores.values()) / len(qid2scores)
        print(average_score_without_dep, qid2answer, qid2scores)

        if self.verbose:
            print("#"*50)
            print("3) Zero-out scores from invalid questions")
            print("#"*50)

        # 3) zero-out scores from invalid questions 
        qid2validity = {}
        qid2scores_after_filtering = deepcopy(qid2scores)

        # print('qid2scores', qid2scores)
        # print('qid2dependency', qid2dependency)
        for id, parent_ids in qid2dependency.items():
            # zero-out scores if parent questions are answered 'no'
            any_parent_answered_no = False
            for parent_id in parent_ids:
                parent_id = list(parent_id)[0]
                if parent_id == 0:
                    continue
                if qid2scores[parent_id] == 0:
                    any_parent_answered_no = True
                    break
            if any_parent_answered_no:
                qid2scores_after_filtering[id] = 0.0
                qid2validity[id] = False
            else:
                qid2validity[id] = True

        if self.verbose:
            print("Per-question eval results (after using dependency)")
            for id in qid2question:
                print("ID", id)
                print("question", qid2question[id])
                print("answer", qid2answer[id])
                print("validity", qid2validity[id])
                print("score (before filtering)", qid2scores[id])
                print("score (after filtering)", qid2scores_after_filtering[id])
                print()

        if self.verbose:
            print("#"*50)
            print("4) Calculate the final score by averaging")
            print("#"*50)

        average_score_with_dep = sum(qid2scores_after_filtering.values()) / len(qid2scores)

        return {
            'frame_index': frame_index,
            'qid2tuple': qid2tuple,
            'qid2dependency': qid2dependency,
            'qid2question': qid2question,
            'qid2answer': qid2answer,
            'qid2scores': qid2scores,
            'qid2validity': qid2validity,
            'average_score_with_dependency': average_score_with_dep,
            'average_score_without_dependency': average_score_without_dep
        }


    def process(self, data_batch: Sequence, data_samples: Sequence) -> None:
        """DSGScore process

        Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence): A batch of data from the dataloader.
            data_samples (Sequence): A batch of data samples that
                contain annotations and predictions.
        """

        result = dict()

        input_qid_lists, input_videos = data_samples
        bsz = len(input_qid_lists)
        # print('input_qid_lists: ', input_qid_lists)

        # Ensure the qid lists and videos are mutable lists (the dataloader may yield tuples)
        if isinstance(input_qid_lists, tuple):
            input_qid_lists = list(input_qid_lists)

        if isinstance(input_videos, tuple):
            input_videos = list(input_videos)

        average_dep_score_list, average_wo_dep_score_list = [], []
        for input_qid_list, input_video in zip([input_qid_lists], input_videos):
            evaluate_dict_list = []
            dep_score, wo_dep_score = [], []
            for index, frame in enumerate(input_video):
                # print('input_qid_list: ', input_qid_list)
                evaluate_dict = self.evaluate_image_dsg(qid_list=input_qid_list, 
                                                        frame_index=index, 
                                                        frame=frame)
                evaluate_dict_list.append(evaluate_dict)
                frame_average_score_with_dependency = evaluate_dict['average_score_with_dependency']
                dep_score.append(frame_average_score_with_dependency)
                frame_average_score_without_dependency = evaluate_dict['average_score_without_dependency']
                wo_dep_score.append(frame_average_score_without_dependency)
            avg_dep_score, avg_wo_dep_score = sum(dep_score)/len(dep_score), sum(wo_dep_score)/len(dep_score)
            average_dep_score_list.append(avg_dep_score)
            average_wo_dep_score_list.append(avg_wo_dep_score)


        result['average_dep_dgs_score'] = sum(average_dep_score_list)/len(average_dep_score_list)
        result['average_wo_dep_dgs_score'] = sum(average_wo_dep_score_list)/len(average_wo_dep_score_list)

        self.results.append(result)


    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
            the metrics, and the values are corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        dep_dsg_score_np = np.zeros(len(results))
        wo_dep_dsg_score_np = np.zeros(len(results))
        for i, result in enumerate(results):
            dep_dsg_score_np[i] = result['average_dep_dgs_score']
            wo_dep_dsg_score_np[i] = result['average_wo_dep_dgs_score']

        dep_dsg_score_np_mean = np.mean(dep_dsg_score_np) 
        wo_dep_dsg_score_np_mean = np.mean(wo_dep_dsg_score_np)

        print("Test results: dsg score with dependency={:.4f}"
              .format(dep_dsg_score_np_mean))
        print("Test results: dsg score without dependency={:.4f}"
              .format(wo_dep_dsg_score_np_mean))

        return {
            'average_dep_dgs_score': dep_dsg_score_np_mean,
            'average_wo_dep_dgs_score': wo_dep_dsg_score_np_mean,
        }
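
For reference, evaluate_image_dsg unpacks qid_list into three aligned dictionaries. The toy values below only illustrate the expected structure; real entries come from the DSG question-generation step.

# Toy structure of ``qid_list`` as unpacked in ``evaluate_image_dsg`` (content is illustrative).
qid2tuple = {1: "entity - car", 2: "attribute - red (car)"}
qid2dependency = {1: [0], 2: [1]}   # question 2 depends on question 1; 0 means no parent
qid2question = {1: "Is there a car?", 2: "Is the car red?"}
qid_list = (qid2tuple, qid2dependency, qid2question)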

compute_metrics(results)

Compute the metrics from processed results.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| results | list | The processed results of each batch. | required |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, float] | The computed metrics. The keys are the names of the metrics, and the values are corresponding results. |

Source code in aigve/metrics/text_video_alignment/gpt_based/dsg/dsg_eval.py
def compute_metrics(self, results: list) -> Dict[str, float]:
    """Compute the metrics from processed results.

    Args:
        results (list): The processed results of each batch.

    Returns:
        Dict[str, float]: The computed metrics. The keys are the names of
        the metrics, and the values are corresponding results.
    """
    logger: MMLogger = MMLogger.get_current_instance()

    dep_dsg_score_np = np.zeros(len(results))
    wo_dep_dsg_score_np = np.zeros(len(results))
    for i, result in enumerate(results):
        dep_dsg_score_np[i] = result['average_dep_dgs_score']
        wo_dep_dsg_score_np[i] = result['average_wo_dep_dgs_score']

    dep_dsg_score_np_mean = np.mean(dep_dsg_score_np) 
    wo_dep_dsg_score_np_mean = np.mean(wo_dep_dsg_score_np)

    print("Test results: dsg score with dependency={:.4f}"
          .format(dep_dsg_score_np_mean))
    print("Test results: dsg score without dependency={:.4f}"
          .format(wo_dep_dsg_score_np_mean))

    return result

evaluate_image_dsg(qid_list, frame_index, frame)

Evaluate a generated image with DSG evaluator; this is the intermediate process of the process function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| qid_list | List[str] | The list of DSG parse question generation results. | required |
| frame_index | int | The index number of the currently evaluated frame. | required |
| frame | List[List[float]] | The currently evaluated frame. | required |

Returns:

Dict[str, Union[int, dict, float]]: A dictionary containing evaluation results with the following keys:

- 'frame_index' (int): The index of the evaluated frame.
- 'qid2tuple' (dict): Mapping of question IDs to tuples.
- 'qid2dependency' (dict): Mapping of question IDs to dependencies.
- 'qid2question' (dict): Mapping of question IDs to actual questions.
- 'qid2answer' (dict): Mapping of question IDs to predicted answers.
- 'qid2scores' (dict): Mapping of question IDs to scores before dependency filtering.
- 'qid2validity' (dict): Mapping of question IDs to boolean validity after dependency filtering.
- 'average_score_with_dependency' (float): Average score considering dependency filtering.
- 'average_score_without_dependency' (float): Average score before dependency filtering.

Source code in aigve/metrics/text_video_alignment/gpt_based/dsg/dsg_eval.py
def evaluate_image_dsg(self, qid_list, frame_index, frame) -> Dict[str, Union[int, dict, float]]:
    """ Evaluate a generated image with DSG evaluator; this is the intermediate process of the ``process`` function. 

    Args:
        qid_list (List[str]): The list of DSG parse question generation results.
        frame_index (int): The index number of the currently evaluated frame.
        frame (List[List[float]]): The current evaluated frame.

    Returns:
        Dict[str, Union[int, dict, float]]: A dictionary containing evaluation results with the following keys:
            - 'frame_index' (int): The index of the evaluated frame.
            - 'qid2tuple' (dict): Mapping of question IDs to tuples.
            - 'qid2dependency' (dict): Mapping of question IDs to dependencies.
            - 'qid2question' (dict): Mapping of question IDs to actual questions.
            - 'qid2answer' (dict): Mapping of question IDs to predicted answers.
            - 'qid2scores' (dict): Mapping of question IDs to scores before dependency filtering.
            - 'qid2validity' (dict): Mapping of question IDs to boolean validity after dependency filtering.
            - 'average_score_with_dependency' (float): Average score considering dependency filtering.
            - 'average_score_without_dependency' (float): Average score before dependency filtering.
    """
    if self.verbose:
        print("#"*50)
        print("2) Answer questions given the generated image, with VQA")
        print("#"*50)

    # 2) answer questions with the generated image
    qid2answer = {}
    qid2scores = {}

    qid2tuple, qid2dependency, qid2question = qid_list
    for id, question in qid2question.items():
        answer = self.vqa_model.vqa(image=frame, question=question)
        print(answer)
        qid2answer[id] = answer
        qid2scores[id] = float('yes' in answer)

    average_score_without_dep = sum(qid2scores.values()) / len(qid2scores)
    print(average_score_without_dep, qid2answer, qid2scores)

    if self.verbose:
        print("#"*50)
        print("3) Zero-out scores from invalid questions")
        print("#"*50)

    # 3) zero-out scores from invalid questions 
    qid2validity = {}
    qid2scores_after_filtering = deepcopy(qid2scores)

    # print('qid2scores', qid2scores)
    # print('qid2dependency', qid2dependency)
    for id, parent_ids in qid2dependency.items():
        # zero-out scores if parent questions are answered 'no'
        any_parent_answered_no = False
        for parent_id in parent_ids:
            parent_id = list(parent_id)[0]
            if parent_id == 0:
                continue
            if qid2scores[parent_id] == 0:
                any_parent_answered_no = True
                break
        if any_parent_answered_no:
            qid2scores_after_filtering[id] = 0.0
            qid2validity[id] = False
        else:
            qid2validity[id] = True

    if self.verbose:
        print("Per-question eval results (after using dependency)")
        for id in qid2question:
            print("ID", id)
            print("question", qid2question[id])
            print("answer", qid2answer[id])
            print("validity", qid2validity[id])
            print("score (before filtering)", qid2scores[id])
            print("score (after filtering)", qid2scores_after_filtering[id])
            print()

    if self.verbose:
        print("#"*50)
        print("4) Calculate the final score by averaging")
        print("#"*50)

    average_score_with_dep = sum(qid2scores_after_filtering.values()) / len(qid2scores)

    return {
        'frame_index': frame_index,
        'qid2tuple': qid2tuple,
        'qid2dependency': qid2dependency,
        'qid2question': qid2question,
        'qid2answer': qid2answer,
        'qid2scores': qid2scores,
        'qid2validity': qid2validity,
        'average_score_with_dependency': average_score_with_dep,
        'average_score_without_dependency': average_score_without_dep
    }
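
The dependency-filtering step above can be illustrated in isolation. The sketch below uses toy question ids and scores and a simplified version of the loop (no VQA model involved): a question whose parent was answered 'no' has its score zeroed out before averaging.

# Toy illustration of DSG dependency filtering (simplified; not library code).
from copy import deepcopy

qid2scores = {1: 1.0, 2: 0.0, 3: 1.0}       # VQA answered 'yes' to q1/q3 and 'no' to q2
qid2dependency = {1: [0], 2: [0], 3: [2]}   # q3 depends on q2; 0 means no parent

qid2scores_after_filtering = deepcopy(qid2scores)
for qid, parent_ids in qid2dependency.items():
    if any(pid != 0 and qid2scores[pid] == 0 for pid in parent_ids):
        qid2scores_after_filtering[qid] = 0.0   # invalidate children of 'no' answers

print(sum(qid2scores.values()) / len(qid2scores))                  # 0.667 without dependency
print(sum(qid2scores_after_filtering.values()) / len(qid2scores))  # 0.333 with dependency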

process(data_batch, data_samples)

DSGScore process

Process one batch of data samples and predictions. The processed results should be stored in self.results, which will be used to compute the metrics when all batches have been processed.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_batch | Sequence | A batch of data from the dataloader. | required |
| data_samples | Sequence | A batch of data samples that contain annotations and predictions. | required |
Source code in aigve/metrics/text_video_alignment/gpt_based/dsg/dsg_eval.py
def process(self, data_batch: Sequence, data_samples: Sequence) -> None:
    """DSGScore process

    Process one batch of data samples and predictions. The processed
    results should be stored in ``self.results``, which will be used to
    compute the metrics when all batches have been processed.

    Args:
        data_batch (Sequence): A batch of data from the dataloader.
        data_samples (Sequence): A batch of data samples that
            contain annotations and predictions.
    """

    result = dict()

    input_qid_lists, input_videos = data_samples
    bsz = len(input_qid_lists)
    # print('input_qid_lists: ', input_qid_lists)

    # Ensure the qid lists and videos are mutable lists (the dataloader may yield tuples)
    if isinstance(input_qid_lists, tuple):
        input_qid_lists = list(input_qid_lists)

    if isinstance(input_videos, tuple):
        input_videos = list(input_videos)

    average_dep_score_list, average_wo_dep_score_list = [], []
    for input_qid_list, input_video in zip([input_qid_lists], input_videos):
        evaluate_dict_list = []
        dep_score, wo_dep_score = [], []
        for index, frame in enumerate(input_video):
            # print('input_qid_list: ', input_qid_list)
            evaluate_dict = self.evaluate_image_dsg(qid_list=input_qid_list, 
                                                    frame_index=index, 
                                                    frame=frame)
            evaluate_dict_list.append(evaluate_dict)
            frame_average_score_with_dependency = evaluate_dict['average_score_with_dependency']
            dep_score.append(frame_average_score_with_dependency)
            frame_average_score_without_dependency = evaluate_dict['average_score_without_dependency']
            wo_dep_score.append(frame_average_score_without_dependency)
        avg_dep_score, avg_wo_dep_score = sum(dep_score)/len(dep_score), sum(wo_dep_score)/len(dep_score)
        average_dep_score_list.append(avg_dep_score)
        average_wo_dep_score_list.append(avg_wo_dep_score)


    result['average_dep_dgs_score'] = sum(average_dep_score_list)/len(average_dep_score_list)
    result['average_wo_dep_dgs_score'] = sum(average_wo_dep_score_list)/len(average_wo_dep_score_list)

    self.results.append(result)

FVDScore

Bases: BaseMetric

Fréchet Video Distance (FVD) computation using an I3D model. Users can first download the pretrained I3D model from https://github.com/hassony2/kinetics_i3d_pytorch/blob/master/model/model_rgb.pth and then put it in the folder AIGVE_Tool/aigve/metrics/video_quality_assessment/distribution_based/fvd/.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_path | str | Path to the pre-trained I3D model. | required |
| feature_layer | int | Layer to extract features from (-2 is the penultimate layer). | -2 |
| is_gpu | bool | Whether to use the GPU. | True |
Source code in aigve/metrics/video_quality_assessment/distribution_based/fvd/fvd_metric.py
@METRICS.register_module()
class FVDScore(BaseMetric):
    """
    Fréchet Video Distance (FVD) computation using I3D model.
    Users can first download the pretrained I3D model from: 
    https://github.com/hassony2/kinetics_i3d_pytorch/blob/master/model/model_rgb.pth
    Then put it in the folder: 
    AIGVE_Tool/aigve/metrics/video_quality_assessment/distribution_based/fvd/

    Args:
        model_path (str): Path to pre-trained I3D model.
        feature_layer (int): Layer to extract features from. Default is -2 (penultimate layer).
        is_gpu (bool): Whether to use GPU. Default is True.
    """
    def __init__(self, 
                 model_path: str, 
                 feature_layer: int = -2, 
                 is_gpu: bool = True):
        super(FVDScore, self).__init__()
        self.device = torch.device("cuda" if is_gpu and torch.cuda.is_available() else "cpu")
        self.model = self.load_i3d_model(model_path, feature_layer)
        self.model.eval()

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # I3D input size
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
        ])

    def load_i3d_model(self, model_path: str, feature_layer: int) -> torch.nn.Module:
        """
        Load a pre-trained I3D model and modify it to extract features.

        Args:
            model_path (str): Path to the I3D model checkpoint.
            feature_layer (int): The layer index from which to extract features.

        Returns:
            torch.nn.Module: I3D feature extraction model.
        """
        model = models.video.r3d_18(pretrained=True)  # Using ResNet3D as an I3D alternative
        model.fc = nn.Identity()  # Remove classification head

        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=self.device))
        else:
            print(f"Warning: Model checkpoint not found at {model_path}, using default weights.")

        return model

    def preprocess_tensor(self, video_tensor: torch.Tensor) -> torch.Tensor:
        """
        Resize and normalize a video tensor.

        Args:
            video_tensor (torch.Tensor): Tensor of shape [T, C, H, W].

        Returns:
            torch.Tensor: Preprocessed tensor of shape [T, C, H, W].
        """
        return self.transform(video_tensor / 255.0)

    def calculate_statistics(self, video_tensor: torch.Tensor) -> tuple[np.ndarray, np.ndarray]:
        """
        Extract activation statistics from video frames.

        Args:
            video_tensor (torch.Tensor): Video tensor [T, C, H, W].

        Returns:
            Tuple[np.ndarray, np.ndarray]: Mean and covariance of extracted features.
        """
        video_tensor = self.preprocess_tensor(video_tensor).to(self.device)
        self.model.to(self.device)
        # Permute to match I3D input format [B, C, T, H, W]
        video_tensor = video_tensor.permute(1, 0, 2, 3).unsqueeze(0)  # Shape: [1, 3, T, H, W]
        with torch.no_grad():
            features = self.model(video_tensor).cpu().numpy()

        # print('features: ', features.shape)
        mu = features.mean(axis=0)
        # Ensure at least 2 samples to compute covariance
        if features.shape[0] > 1:
            sigma = np.cov(features, rowvar=False)
        else:
            sigma = np.zeros((features.shape[1], features.shape[1]))  # Zero-covariance fallback for a single feature vector
        return mu, sigma

    def calculate_fvd(self, real: torch.Tensor, fake: torch.Tensor) -> float:
        """
        Compute FVD score between real and generated videos.

        Args:
            real (torch.Tensor): Real video tensor [T, C, H, W].
            fake (torch.Tensor): Generated video tensor [T, C, H, W].

        Returns:
            float: FVD score.
        """
        mu1, sigma1 = self.calculate_statistics(real) # Shape[512], Shape[512, 512]
        mu2, sigma2 = self.calculate_statistics(fake)
        # print(f"mu1 shape: {mu1.shape}, sigma1 shape: {sigma1.shape}")
        # print(f"mu2 shape: {mu2.shape}, sigma2 shape: {sigma2.shape}")

        # Ensure sigma matrices are at least 2D
        if sigma1.ndim < 2:
            sigma1 = np.expand_dims(sigma1, axis=0)
        if sigma2.ndim < 2:
            sigma2 = np.expand_dims(sigma2, axis=0)

        ssdiff = np.sum((mu1 - mu2) ** 2.0)
        covmean = sqrtm(sigma1 @ sigma2)

        # Check and correct for imaginary numbers
        if np.iscomplexobj(covmean):
            covmean = covmean.real

        return ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """
        Process a batch of videos and compute FVD.

        Args:
            data_batch (dict): Not used here.
            data_samples (List[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[str], Tuple[str]]):
                A list containing four tuples:
                - A tuple of `real_tensor` (torch.Tensor): Real video tensor [T, C, H, W].
                - A tuple of `gen_tensor` (torch.Tensor): Generated video tensor [T, C, H, W].
                - A tuple of `real_video_name` (str): Ground-truth video filename.
                - A tuple of `gen_video_name` (str): Generated video filename.
                The length of each tuple equals the batch size.
        """
        results = []
        real_tensor_tuple, gen_tensor_tuple, real_video_name_tuple, gen_video_name_tuple = data_samples

        batch_size = len(real_tensor_tuple)
        with torch.no_grad():
            for i in range(batch_size):
                real_video_name = real_video_name_tuple[i]
                gen_video_name = gen_video_name_tuple[i]
                real_tensor = real_tensor_tuple[i]
                gen_tensor = gen_tensor_tuple[i]

                fvd_score = self.calculate_fvd(real_tensor, gen_tensor)

                results.append({
                    "Real video_name": real_video_name, 
                    "Generated video_name": gen_video_name, 
                    "FVD_Score": fvd_score
                })
                print(f"Processed FVD score {fvd_score:.4f} between {real_video_name} and {gen_video_name}")

        self.results.extend(results)

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """
        Compute the final FVD score.

        Args:
            results (list): List of FVD scores for each batch.

        Returns:
            Dict[str, float]: Dictionary containing mean FVD score.
        """
        scores = np.array([res["FVD_Score"] for res in self.results])
        mean_score = np.mean(scores) if scores.size > 0 else 0.0
        print(f"FVD mean score: {mean_score:.4f}")

        json_file_path = os.path.join(os.getcwd(), "fvd_results.json")
        final_results = {
            "video_results": self.results, 
            "FVD_Mean_Score": mean_score
        }
        with open(json_file_path, "w") as json_file:
            json.dump(final_results, json_file, indent=4)
        print(f"FVD mean score saved to {json_file_path}")

        return {"FVD_Mean_Score": mean_score}
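
A hedged end-to-end sketch follows: it pairs FVDScore with FidDataset (documented further down), which yields exactly the four-tuple samples that process expects. The paths and the collate function are assumptions for illustration, not the tool's official runner.

# Hedged sketch: paths are hypothetical; the collate_fn regroups samples into four tuples.
from torch.utils.data import DataLoader

dataset = FidDataset(video_dir="videos/", prompt_dir="annotations.json",
                     max_len=64, if_pad=True)
loader = DataLoader(dataset, batch_size=1,
                    collate_fn=lambda batch: tuple(zip(*batch)))

metric = FVDScore(model_path="metrics/video_quality_assessment/distribution_based/fvd/model_rgb.pth")

for batch in loader:
    # batch holds four tuples: real tensors, generated tensors, real names, generated names.
    metric.process(data_batch=None, data_samples=batch)

print(metric.compute_metrics(metric.results))  # {'FVD_Mean_Score': ...}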

calculate_fvd(real, fake)

Compute FVD score between real and generated videos.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| real | Tensor | Real video tensor [T, C, H, W]. | required |
| fake | Tensor | Generated video tensor [T, C, H, W]. | required |

Returns:

| Type | Description |
| --- | --- |
| float | FVD score. |

Source code in aigve/metrics/video_quality_assessment/distribution_based/fvd/fvd_metric.py
def calculate_fvd(self, real: torch.Tensor, fake: torch.Tensor) -> float:
    """
    Compute FVD score between real and generated videos.

    Args:
        real (torch.Tensor): Real video tensor [T, C, H, W].
        fake (torch.Tensor): Generated video tensor [T, C, H, W].

    Returns:
        float: FVD score.
    """
    mu1, sigma1 = self.calculate_statistics(real) # Shape[512], Shape[512, 512]
    mu2, sigma2 = self.calculate_statistics(fake)
    # print(f"mu1 shape: {mu1.shape}, sigma1 shape: {sigma1.shape}")
    # print(f"mu2 shape: {mu2.shape}, sigma2 shape: {sigma2.shape}")

    # Ensure sigma matrices are at least 2D
    if sigma1.ndim < 2:
        sigma1 = np.expand_dims(sigma1, axis=0)
    if sigma2.ndim < 2:
        sigma2 = np.expand_dims(sigma2, axis=0)

    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1 @ sigma2)

    # Check and correct for imaginary numbers
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
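
The quantity returned above is the Fréchet distance between two Gaussians fitted to the feature statistics, d^2 = ||mu1 - mu2||^2 + Tr(Sigma1 + Sigma2 - 2*(Sigma1 Sigma2)^(1/2)). The sketch below recomputes it with NumPy on toy feature matrices, purely as an illustration of the formula.

# Toy check of the Fréchet distance formula (illustrative only).
import numpy as np
from scipy.linalg import sqrtm

rng = np.random.default_rng(0)
feats_real = rng.normal(size=(100, 8))            # 100 feature vectors of dimension 8
feats_fake = rng.normal(loc=0.5, size=(100, 8))   # slightly shifted distribution

mu1, sigma1 = feats_real.mean(axis=0), np.cov(feats_real, rowvar=False)
mu2, sigma2 = feats_fake.mean(axis=0), np.cov(feats_fake, rowvar=False)

covmean = sqrtm(sigma1 @ sigma2)
if np.iscomplexobj(covmean):
    covmean = covmean.real                        # drop small numerical imaginary parts

fd = np.sum((mu1 - mu2) ** 2) + np.trace(sigma1 + sigma2 - 2.0 * covmean)
print(f"Frechet distance on toy features: {fd:.4f}")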

calculate_statistics(video_tensor)

Extract activation statistics from video frames.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| video_tensor | Tensor | Video tensor [T, C, H, W]. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[ndarray, ndarray] | Mean and covariance of extracted features. |

Source code in aigve/metrics/video_quality_assessment/distribution_based/fvd/fvd_metric.py
def calculate_statistics(self, video_tensor: torch.Tensor) -> tuple[np.ndarray, np.ndarray]:
    """
    Extract activation statistics from video frames.

    Args:
        video_tensor (torch.Tensor): Video tensor [T, C, H, W].

    Returns:
        Tuple[np.ndarray, np.ndarray]: Mean and covariance of extracted features.
    """
    video_tensor = self.preprocess_tensor(video_tensor).to(self.device)
    self.model.to(self.device)
    # Permute to match I3D input format [B, C, T, H, W]
    video_tensor = video_tensor.permute(1, 0, 2, 3).unsqueeze(0)  # Shape: [1, 3, T, H, W]
    with torch.no_grad():
        features = self.model(video_tensor).cpu().numpy()

    # print('features: ', features.shape)
    mu = features.mean(axis=0)
    # Ensure at least 2 samples to compute covariance
    if features.shape[0] > 1:
        sigma = np.cov(features, rowvar=False)
    else:
        sigma = np.zeros((features.shape[1], features.shape[1]))  # Zero-covariance fallback for a single feature vector
    return mu, sigma

compute_metrics(results)

Compute the final FVD score.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| results | list | List of FVD scores for each batch. | required |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, float] | Dictionary containing the mean FVD score. |

Source code in aigve/metrics/video_quality_assessment/distribution_based/fvd/fvd_metric.py
def compute_metrics(self, results: list) -> Dict[str, float]:
    """
    Compute the final FVD score.

    Args:
        results (list): List of FVD scores for each batch.

    Returns:
        Dict[str, float]: Dictionary containing mean FVD score.
    """
    scores = np.array([res["FVD_Score"] for res in self.results])
    mean_score = np.mean(scores) if scores.size > 0 else 0.0
    print(f"FVD mean score: {mean_score:.4f}")

    json_file_path = os.path.join(os.getcwd(), "fvd_results.json")
    final_results = {
        "video_results": self.results, 
        "FVD_Mean_Score": mean_score
    }
    with open(json_file_path, "w") as json_file:
        json.dump(final_results, json_file, indent=4)
    print(f"FVD mean score saved to {json_file_path}")

    return {"FVD_Mean_Score": mean_score}

load_i3d_model(model_path, feature_layer)

Load a pre-trained I3D model and modify it to extract features.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_path | str | Path to the I3D model checkpoint. | required |
| feature_layer | int | The layer index from which to extract features. | required |

Returns:

| Type | Description |
| --- | --- |
| torch.nn.Module | I3D feature extraction model. |

Source code in aigve/metrics/video_quality_assessment/distribution_based/fvd/fvd_metric.py
def load_i3d_model(self, model_path: str, feature_layer: int) -> torch.nn.Module:
    """
    Load a pre-trained I3D model and modify it to extract features.

    Args:
        model_path (str): Path to the I3D model checkpoint.
        feature_layer (int): The layer index from which to extract features.

    Returns:
        torch.nn.Module: I3D feature extraction model.
    """
    model = models.video.r3d_18(pretrained=True)  # Using ResNet3D as an I3D alternative
    model.fc = nn.Identity()  # Remove classification head

    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=self.device))
    else:
        print(f"Warning: Model checkpoint not found at {model_path}, using default weights.")

    return model

preprocess_tensor(video_tensor)

Resize and normalize a video tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| video_tensor | Tensor | Tensor of shape [T, C, H, W]. | required |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | Preprocessed tensor of shape [T, C, H, W]. |

Source code in aigve/metrics/video_quality_assessment/distribution_based/fvd/fvd_metric.py
def preprocess_tensor(self, video_tensor: torch.Tensor) -> torch.Tensor:
    """
    Resize and normalize a video tensor.

    Args:
        video_tensor (torch.Tensor): Tensor of shape [T, C, H, W].

    Returns:
        torch.Tensor: Preprocessed tensor of shape [T, C, H, W].
    """
    return self.transform(video_tensor / 255.0)

process(data_batch, data_samples)

Process a batch of videos and compute FVD.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_batch | dict | Not used here. | required |
| data_samples | List[Tuple[Tensor], Tuple[Tensor], Tuple[str], Tuple[str]] | A list containing four tuples: the real video tensors [T, C, H, W], the generated video tensors [T, C, H, W], the ground-truth video filenames, and the generated video filenames. The length of each tuple equals the batch size. | required |
Source code in aigve/metrics/video_quality_assessment/distribution_based/fvd/fvd_metric.py
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
    """
    Process a batch of videos and compute FVD.

    Args:
        data_batch (dict): Not used here.
        data_samples (List[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[str], Tuple[str]]):
            A list containing four tuples:
            - A tuple of `real_tensor` (torch.Tensor): Real video tensor [T, C, H, W].
            - A tuple of `gen_tensor` (torch.Tensor): Generated video tensor [T, C, H, W].
            - A tuple of `real_video_name` (str): Ground-truth video filename.
            - A tuple of `gen_video_name` (str): Generated video filename.
            The length of each tuple equals the batch size.
    """
    results = []
    real_tensor_tuple, gen_tensor_tuple, real_video_name_tuple, gen_video_name_tuple = data_samples

    batch_size = len(real_tensor_tuple)
    with torch.no_grad():
        for i in range(batch_size):
            real_video_name = real_video_name_tuple[i]
            gen_video_name = gen_video_name_tuple[i]
            real_tensor = real_tensor_tuple[i]
            gen_tensor = gen_tensor_tuple[i]

            fvd_score = self.calculate_fvd(real_tensor, gen_tensor)

            results.append({
                "Real video_name": real_video_name, 
                "Generated video_name": gen_video_name, 
                "FVD_Score": fvd_score
            })
            print(f"Processed FVD score {fvd_score:.4f} between {real_video_name} and {gen_video_name}")

    self.results.extend(results)

FidDataset

Bases: Dataset

Dataset for Fréchet Inception Distance (FID) evaluation.

For each sample, this dataset:

- Loads both the ground-truth (real) and generated (predicted) videos.
- Converts each video into a tensor of shape [T, C, H, W] using OpenCV.
- Optionally pads or truncates videos to a fixed number of frames.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| video_dir | str | Directory containing video files. | required |
| prompt_dir | str | Path to the JSON file that lists ground-truth and generated video filenames. | required |
| max_len | int | Maximum number of frames to load per video. | 500 |
| if_pad | bool | Whether to pad videos to exactly max_len frames. If False, videos can have variable lengths. | False |
Source code in aigve/datasets/fid_dataset.py
@DATASETS.register_module()
class FidDataset(Dataset):
    """
    Dataset for Fréchet Inception Distance (FID) evaluation.

    For each sample, this dataset:
        - Loads both the ground-truth (real) and generated (predicted) videos.
        - Converts each video into a tensor of shape [T, C, H, W] using OpenCV.
        - Optionally pads or truncates videos to a fixed number of frames.

    Args:
        video_dir (str): Directory containing video files.
        prompt_dir (str): Path to JSON file that lists ground-truth and generated video filenames.
        max_len (int): Maximum number of frames to load per video. Default: 500.
        if_pad (bool): Whether to pad videos to exactly `max_len` frames. If False, videos can have variable lengths.
    """

    def __init__(self, video_dir, prompt_dir, max_len=500, if_pad=False):
        super(FidDataset, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.video_dir = video_dir
        self.prompt_dir = prompt_dir
        self.max_len = max_len
        self.if_pad = if_pad

        self.gt_video_names, self.gen_video_names = self._read_video_names()

    def _read_video_names(self):
        """Reads video names from the dataset JSON file."""
        with open(self.prompt_dir, 'r') as reader:
            read_data = json.load(reader)
            gt_video_names = [item['video_path_gt'].strip() for item in read_data["data_list"]]
            gen_video_names = [item['video_path_pd'].strip() for item in read_data["data_list"]]
        return gt_video_names, gen_video_names

    def _load_video_tensor(self, video_path: str) -> torch.Tensor:
        """Load a video and return its tensor of shape [T, C, H, W]."""
        assert os.path.exists(video_path), f"Video file not found: {video_path}"
        cap = cv2.VideoCapture(video_path)
        input_frames = []
        frame_count = 0
        while cap.isOpened() and frame_count < self.max_len:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            input_frames.append(torch.tensor(frame).float())
            frame_count += 1

        cap.release()

        if len(input_frames) == 0:
            raise RuntimeError(f"No valid frames found in {video_path}")

        if self.if_pad:
            num_frames = len(input_frames)
            if num_frames < self.max_len:
                pad_frames = torch.zeros((self.max_len - num_frames, *input_frames[0].shape))
                video_tensor = torch.cat((torch.stack(input_frames), pad_frames), dim=0)
            else:
                video_tensor = torch.stack(input_frames[:self.max_len])
        else:
            video_tensor = torch.stack(input_frames)

        # Convert from [T, H, W, C] to [T, C, H, W]
        return video_tensor.permute(0, 3, 1, 2)

    def __len__(self):
        return len(self.gt_video_names)

    def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor, str, str]:
        """
        Returns:
            Tuple[torch.Tensor, torch.Tensor, str, str]: 
                - Ground-truth (Real) video tensor of shape [T, C, H, W].
                - Generated video tensor of shape [T, C, H, W].
                - Ground-truth video name.
                - Generated video name.
        """
        gt_video_name = self.gt_video_names[index]
        gt_video_path = os.path.join(self.video_dir, gt_video_name)
        gen_video_name = self.gen_video_names[index]
        gen_video_path = os.path.join(self.video_dir, gen_video_name) 

        gt_video_tensor = self._load_video_tensor(gt_video_path)
        gen_video_tensor = self._load_video_tensor(gen_video_path)

        return gt_video_tensor, gen_video_tensor, gt_video_name, gen_video_name
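
A minimal usage sketch. The directory, JSON path, and import path below are placeholders (the import path is inferred from the "Source code in aigve/datasets/fid_dataset.py" note); the registry/config pipeline is the usual way to instantiate the dataset:

from torch.utils.data import DataLoader
from aigve.datasets.fid_dataset import FidDataset  # module path assumed from the source note

dataset = FidDataset(
    video_dir="data/videos",             # placeholder directory with real and generated videos
    prompt_dir="data/annotations.json",  # placeholder JSON with a "data_list" of video pairs
    max_len=64,
    if_pad=True,                         # pad every video to exactly max_len frames
)
loader = DataLoader(dataset, batch_size=2)

for gt_videos, gen_videos, gt_names, gen_names in loader:
    # With if_pad=True, both tensors are [B, max_len, C, H, W] (assuming equal frame sizes).
    print(gt_videos.shape, gen_videos.shape, gt_names, gen_names)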

__getitem__(index)

Returns:

Type Description
tuple[Tensor, Tensor, str, str]

Tuple[torch.Tensor, torch.Tensor, str, str]:

- Ground-truth (real) video tensor of shape [T, C, H, W].
- Generated video tensor of shape [T, C, H, W].
- Ground-truth video name.
- Generated video name.

Source code in aigve/datasets/fid_dataset.py
def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor, str, str]:
    """
    Returns:
        Tuple[torch.Tensor, torch.Tensor, str, str]: 
            - Ground-truth (Real) video tensor of shape [T, C, H, W].
            - Generated video tensor of shape [T, C, H, W].
            - Ground-truth video name.
            - Generated video name.
    """
    gt_video_name = self.gt_video_names[index]
    gt_video_path = os.path.join(self.video_dir, gt_video_name)
    gen_video_name = self.gen_video_names[index]
    gen_video_path = os.path.join(self.video_dir, gen_video_name) 

    gt_video_tensor = self._load_video_tensor(gt_video_path)
    gen_video_tensor = self._load_video_tensor(gen_video_path)

    return gt_video_tensor, gen_video_tensor, gt_video_name, gen_video_name

GSTVQADataset

Bases: Dataset

Dataset for GSTVQA metric, supports feature extraction using VGG16 or ResNet.

Source code in aigve/datasets/gstvqa_dataset.py
@DATASETS.register_module()
class GSTVQADataset(Dataset):
    """Dataset for GSTVQA metric, supports feature extraction using VGG16 or ResNet."""

    def __init__(self, video_dir, prompt_dir, model_name='vgg16', max_len=500):
        super(GSTVQADataset, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.video_dir = video_dir
        self.prompt_dir = prompt_dir
        self.model_name = model_name
        self.max_len = max_len
        self.feature_extractor = FeatureExtractor(model_name=model_name)

        self.prompts, self.video_names = self._read_prompt_videoname()

    def _read_prompt_videoname(self):
        with open(self.prompt_dir, 'r') as reader:
            read_data = json.load(reader)

        prompt_data_list, video_name_list = [], []
        for item in read_data["data_list"]:
            prompt = item['prompt_gt'].strip()
            video_name = item['video_path_pd'].strip()
            prompt_data_list.append(prompt)
            video_name_list.append(video_name)

        return prompt_data_list, video_name_list

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, index) -> tuple[torch.Tensor, int, str]:
        """
        Returns a tuple of:
            deep_features (torch.Tensor): Shape [max_len, 2944]
                Mean and std features extracted from input frames using the chosen model (VGG16 or ResNet).
                Padded to self.max_len if the number of frames is less.
            num_frames (int): The number of frames in the video.
            video_name (str): The file name for the video.
        """
        video_name = self.video_names[index]
        video_path = os.path.join(self.video_dir, video_name)
        input_frames = []

        cap = cv2.VideoCapture(video_path)
        frame_count = 0

        while cap.isOpened() and frame_count < self.max_len:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # frame = cv2.resize(frame, self.frame_size)
            input_frames.append(torch.tensor(frame).float())
            frame_count += 1

        cap.release()

        # Pad or truncate frames to max_len
        num_frames = len(input_frames)
        # print('num_frames: ', num_frames)
        if num_frames < 30:
            pad_frames = torch.zeros((30 - num_frames, *input_frames[0].shape))
            input_frames_tensor = torch.cat((torch.stack(input_frames), pad_frames), dim=0)
            num_frames = 30 # Force min frames to be 30 (since two att_frames=15 (kernel_size) are used in GSTVQA)
        elif num_frames < self.max_len:
            pad_frames = torch.zeros((self.max_len - num_frames, *input_frames[0].shape))
            input_frames_tensor = torch.cat((torch.stack(input_frames), pad_frames), dim=0)
        else:
            input_frames_tensor = torch.stack(input_frames[:self.max_len])
        # print('input_frames_tensor: ', input_frames_tensor.shape) # shape: toy data [max_len, H(512), W(512), C(3)]

        # Convert from [T, H, W, C] to [T, C, H, W]
        input_frames_tensor = input_frames_tensor.permute(0, 3, 1, 2) 

        # Extract features using the chosen model (VGG16 or ResNet)
        with torch.no_grad():
            mean_features, std_features = self.feature_extractor(input_frames_tensor) # Shape: [T, 1472]: [10, 1472]

        # Concatenate to match GSTVQA expected 2944-dim features
        deep_features = torch.cat((mean_features, std_features), dim=1)  # Shape: [T, 2944]

        # Ensure output shape [max_len, 2944] (pad if needed)
        if deep_features.shape[0] < self.max_len:
            pad_size = self.max_len - deep_features.shape[0]
            padding = torch.zeros((pad_size, 2944), device=deep_features.device)
            deep_features = torch.cat((deep_features, padding), dim=0)

        return deep_features, num_frames, video_name
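
The pad/truncate rule in __getitem__ can be isolated as a small shape-only sketch; random tensors stand in for decoded frames, and the 30-frame floor mirrors the comment about GSTVQA's kernel size:

import torch

def pad_or_truncate(frames, max_len=500):
    # Mirrors the dataset logic: enforce at least 30 frames and at most max_len.
    n = len(frames)
    if n < 30:
        pad = torch.zeros((30 - n, *frames[0].shape))
        return torch.cat((torch.stack(frames), pad), dim=0)
    if n < max_len:
        pad = torch.zeros((max_len - n, *frames[0].shape))
        return torch.cat((torch.stack(frames), pad), dim=0)
    return torch.stack(frames[:max_len])

frames = [torch.rand(64, 64, 3) for _ in range(10)]   # 10 decoded frames, [H, W, C]
print(pad_or_truncate(frames).shape)                  # torch.Size([30, 64, 64, 3])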

__getitem__(index)

Returns a tuple of

- deep_features (torch.Tensor): Shape [max_len, 2944]. Mean and std features extracted from input frames using the chosen model (VGG16 or ResNet). Padded to self.max_len if the number of frames is less.
- num_frames (int): The number of frames in the video.
- video_name (str): The file name for the video.

Source code in aigve/datasets/gstvqa_dataset.py
def __getitem__(self, index) -> tuple[torch.Tensor, int, str]:
    """
    Returns a tuple of:
        deep_features (torch.Tensor): Shape [max_len, 2944]
            Mean and std features extracted from input frames using the chosen model (VGG16 or ResNet).
            Padded to self.max_len if the number of frames is less.
        num_frames (int): The number of frames in the video.
        video_name (str): The file name for the video.
    """
    video_name = self.video_names[index]
    video_path = os.path.join(self.video_dir, video_name)
    input_frames = []

    cap = cv2.VideoCapture(video_path)
    frame_count = 0

    while cap.isOpened() and frame_count < self.max_len:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # frame = cv2.resize(frame, self.frame_size)
        input_frames.append(torch.tensor(frame).float())
        frame_count += 1

    cap.release()

    # Pad or truncate frames to max_len
    num_frames = len(input_frames)
    # print('num_frames: ', num_frames)
    if num_frames < 30:
        pad_frames = torch.zeros((30 - num_frames, *input_frames[0].shape))
        input_frames_tensor = torch.cat((torch.stack(input_frames), pad_frames), dim=0)
        num_frames = 30 # Force min frames to be 30 (since two att_frames=15 (kernel_size) are used in GSTVQA)
    elif num_frames < self.max_len:
        pad_frames = torch.zeros((self.max_len - num_frames, *input_frames[0].shape))
        input_frames_tensor = torch.cat((torch.stack(input_frames), pad_frames), dim=0)
    else:
        input_frames_tensor = torch.stack(input_frames[:self.max_len])
    # print('input_frames_tensor: ', input_frames_tensor.shape) # shape: toy data [max_len, H(512), W(512), C(3)]

    # Convert from [T, H, W, C] to [T, C, H, W]
    input_frames_tensor = input_frames_tensor.permute(0, 3, 1, 2) 

    # Extract features using the chosen model (VGG16 or ResNet)
    with torch.no_grad():
        mean_features, std_features = self.feature_extractor(input_frames_tensor) # Shape: [T, 1472]: [10, 1472]

    # Concatenate to match GSTVQA expected 2944-dim features
    deep_features = torch.cat((mean_features, std_features), dim=1)  # Shape: [T, 2944]

    # Ensure output shape [max_len, 2944] (pad if needed)
    if deep_features.shape[0] < self.max_len:
        pad_size = self.max_len - deep_features.shape[0]
        padding = torch.zeros((pad_size, 2944), device=deep_features.device)
        deep_features = torch.cat((deep_features, padding), dim=0)

    return deep_features, num_frames, video_name

GstVqa

Bases: BaseMetric

GstVQA metric modified for the toy dataset. (Supporting 2944-dim features).

Source code in aigve/metrics/video_quality_assessment/nn_based/gstvqa/gstvqa_metric.py
@METRICS.register_module()
class GstVqa(BaseMetric):
    """GstVQA metric modified for the toy dataset. (Supporting 2944-dim features)."""

    def __init__(self, model_path: str):
        super(GstVqa, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.submodel_path = os.path.join(os.getcwd(), 'metrics/video_quality_assessment/nn_based/gstvqa')
        if not submodule_exists(self.submodel_path):
            add_git_submodule(
                repo_url='https://github.com/Baoliang93/GSTVQA.git', 
                submodule_path=self.submodel_path
            )
        from .GSTVQA.TCSVT_Release.GVQA_Release.GVQA_Cross.cross_test import GSTVQA as GSTVQA_model
        self.model = GSTVQA_model().to(self.device)
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()
        # self.criterion = nn.L1Loss().to(self.device)

    def compute_stat_features(self, features: torch.Tensor, num_valid_frames: int) -> Tuple[torch.Tensor]:
        """Compute statistical features mean_var, std_var, mean_mean, std_mean from extracted deep features.

        Args:
            features (torch.Tensor): Tensor of shape [T, 2944].
            num_valid_frames (int): Number of valid frames before padding.

        Returns:
            Tuple[torch.Tensor]: (mean_var, std_var, mean_mean, std_mean), each of shape [1472].
        """
        # Ignore padded frames
        features = features[:num_valid_frames]  # Shape: [num_valid_frames, feature_dim]: [10, 1472]

        if num_valid_frames == 0:  # Edge case: all frames were padded
            return (
                torch.zeros(1472, device=self.device),
                torch.zeros(1472, device=self.device),
                torch.zeros(1472, device=self.device),
                torch.zeros(1472, device=self.device),
            )

        # Split into mean and std components
        mean_features = features[:, :1472]  # First 1472 features are mean-based
        std_features = features[:, 1472:]   # Last 1472 features are std-based

        # Compute per-feature statistics over frames
        mean_mean = mean_features.mean(dim=0)  # Shape: [1472]
        std_mean = std_features.mean(dim=0)    # Shape: [1472]
        mean_var = mean_features.var(dim=0, unbiased=False)  # Shape: [1472]
        std_var = std_features.var(dim=0, unbiased=False)    # Shape: [1472]

        return mean_var, std_var, mean_mean, std_mean

    def process(self, data_batch: Sequence, data_samples: Sequence) -> None:
        """
        Process a batch of extracted deep features for GSTVQA evaluation and store results in a JSON file.

        Args:
            data_batch (Sequence): A batch of data from the dataloader (not used here).
            data_samples (List[Tuple[torch.Tensor], Tuple[int], Tuple[str]]):
                A list containing three tuples:
                - A tuple of `deep_features`: Each item is a Tensor of shape [T, 2944]. 
                - A tuple of `num_frames`: Each item is an integer representing the number of valid frames.
                - A tuple of `video_name`: Each item is a string representing the file name for the video.
                The length of each of the three tuples equals the batch size.
        """
        # data_samples an example: [
        #     (tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        #              [0., 0., 0.,  ..., 0., 0., 0.],
        #              ...
        #              [0., 0., 0.,  ..., 0., 0., 0.]]), 
        #      tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        #              [0., 0., 0.,  ..., 0., 0., 0.],
        #              ...
        #              [0., 0., 0.,  ..., 0., 0., 0.]])), 
        #     (10, 10)
        # ]
        results = []
        deep_features_tuple, num_frames_tuple, video_name_tuple = data_samples
        with torch.no_grad():
            for deep_features, num_valid_frames, video_name in zip(deep_features_tuple, num_frames_tuple, video_name_tuple):
                if not isinstance(deep_features, torch.Tensor) or not isinstance(num_valid_frames, int):
                    raise TypeError("Expected deep_features to be a torch.Tensor and num_valid_frames to be an int.")

                if num_valid_frames == 0:  # Edge case: No valid frames
                    results.append({"video_name": 'N/A', "GSTVQA_Score": 0.0})
                    continue

                # Remove padded features
                features = deep_features[:num_valid_frames].to(self.device)

                # Compute statistical features only on valid frames
                mean_var, std_var, mean_mean, std_mean = self.compute_stat_features(features, num_valid_frames)
                mean_var, std_var, mean_mean, std_mean = (
                    mean_var.to(self.device),
                    std_var.to(self.device),
                    mean_mean.to(self.device),
                    std_mean.to(self.device),
                )

                # Length tensor indicating the number of valid frames
                length = torch.tensor([num_valid_frames]).to(self.device)
                # print('features(input) shape', features.unsqueeze(1).shape) # torch.Size([10, 1, 1472])
                # print('input_length shape', length.shape) # torch.Size([1])
                # print('input_length', length) # torch.Size([1])
                # print('mean_mean shape', mean_mean.shape) # torch.Size([1472])
                # print('std_mean shape', std_mean.shape) # torch.Size([1472])
                # print('mean_var shape', mean_var.shape) # torch.Size([1472])
                # print('std_var shape', std_var.shape) # torch.Size([1472])

                # Run GSTVQA model
                outputs = self.model(features.unsqueeze(1), length, mean_var, std_var, mean_mean, std_mean)
                score = outputs.item()
                results.append({"video_name": video_name, "GSTVQA_Score": score})
                # print(f"Processed score {score:.4f} for {video_name}")

        self.results.extend(results)


    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute final GSTVQA-based metrics."""
        scores = np.array([res['GSTVQA_Score'] for res in self.results])
        mean_score = np.mean(scores)
        print(f"GSTVQA mean score: {mean_score:.4f}")

        json_file_path = os.path.join(os.getcwd(), "gstvqa_results.json")
        final_results = {"video_results": self.results, "GSTVQA_Mean_Score": mean_score}
        with open(json_file_path, "w") as json_file:
            json.dump(final_results, json_file, indent=4)
        print(f"GSTVQA mean score saved to {json_file_path}")

        return {'GSTVQA_Mean_Score': mean_score}
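
A minimal end-to-end sketch under stated assumptions: the import paths follow the "Source code in ..." notes, the checkpoint path is a placeholder, and the sample is wrapped manually into per-field lists rather than going through the usual runner/evaluator pipeline:

from aigve.datasets.gstvqa_dataset import GSTVQADataset                                   # path assumed
from aigve.metrics.video_quality_assessment.nn_based.gstvqa.gstvqa_metric import GstVqa   # path assumed

dataset = GSTVQADataset(video_dir="data/videos", prompt_dir="data/annotations.json",
                        model_name="vgg16", max_len=500)
metric = GstVqa(model_path="checkpoints/gstvqa.pth")   # placeholder pretrained weights

deep_features, num_frames, video_name = dataset[0]
# process() expects one tuple/list per field, each of length batch_size.
metric.process(data_batch=None,
               data_samples=([deep_features], [num_frames], [video_name]))
print(metric.compute_metrics(metric.results))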

compute_metrics(results)

Compute final GSTVQA-based metrics.

Source code in aigve/metrics/video_quality_assessment/nn_based/gstvqa/gstvqa_metric.py
def compute_metrics(self, results: list) -> Dict[str, float]:
    """Compute final GSTVQA-based metrics."""
    scores = np.array([res['GSTVQA_Score'] for res in self.results])
    mean_score = np.mean(scores)
    print(f"GSTVQA mean score: {mean_score:.4f}")

    json_file_path = os.path.join(os.getcwd(), "gstvqa_results.json")
    final_results = {"video_results": self.results, "GSTVQA_Mean_Score": mean_score}
    with open(json_file_path, "w") as json_file:
        json.dump(final_results, json_file, indent=4)
    print(f"GSTVQA mean score saved to {json_file_path}")

    return {'GSTVQA_Mean_Score': mean_score}

compute_stat_features(features, num_valid_frames)

Compute statistical features mean_var, std_var, mean_mean, std_mean from extracted deep features.

Parameters:

Name Type Description Default
features Tensor

Tensor of shape [T, 2944].

required
num_valid_frames int

Number of valid frames before padding.

required

Returns:

Type Description
Tuple[Tensor]

Tuple[torch.Tensor]: (mean_var, std_var, mean_mean, std_mean), each of shape [1472].

Source code in aigve/metrics/video_quality_assessment/nn_based/gstvqa/gstvqa_metric.py
def compute_stat_features(self, features: torch.Tensor, num_valid_frames: int) -> Tuple[torch.Tensor]:
    """Compute statistical features mean_var, std_var, mean_mean, std_mean from extracted deep features.

    Args:
        features (torch.Tensor): Tensor of shape [T, 2944].
        num_valid_frames (int): Number of valid frames before padding.

    Returns:
        Tuple[torch.Tensor]: (mean_var, std_var, mean_mean, std_mean), each of shape [1472].
    """
    # Ignore padded frames
    features = features[:num_valid_frames]  # Shape: [num_valid_frames, feature_dim]: [10, 1472]

    if num_valid_frames == 0:  # Edge case: all frames were padded
        return (
            torch.zeros(1472, device=self.device),
            torch.zeros(1472, device=self.device),
            torch.zeros(1472, device=self.device),
            torch.zeros(1472, device=self.device),
        )

    # Split into mean and std components
    mean_features = features[:, :1472]  # First 1472 features are mean-based
    std_features = features[:, 1472:]   # Last 1472 features are std-based

    # Compute per-feature statistics over frames
    mean_mean = mean_features.mean(dim=0)  # Shape: [1472]
    std_mean = std_features.mean(dim=0)    # Shape: [1472]
    mean_var = mean_features.var(dim=0, unbiased=False)  # Shape: [1472]
    std_var = std_features.var(dim=0, unbiased=False)    # Shape: [1472]

    return mean_var, std_var, mean_mean, std_mean
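
A shape-only sketch of the statistics above, with random values standing in for real VGG16/ResNet features:

import torch

T = 10
features = torch.rand(T, 2944)          # per-frame deep features

mean_features = features[:, :1472]      # first half: mean-pooled frame features
std_features = features[:, 1472:]       # second half: std-pooled frame features

mean_mean = mean_features.mean(dim=0)                  # [1472]
std_mean = std_features.mean(dim=0)                    # [1472]
mean_var = mean_features.var(dim=0, unbiased=False)    # [1472]
std_var = std_features.var(dim=0, unbiased=False)      # [1472]
print(mean_var.shape, std_var.shape, mean_mean.shape, std_mean.shape)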

process(data_batch, data_samples)

Process a batch of extracted deep features for GSTVQA evaluation and store results in a JSON file.

Parameters:

Name Type Description Default
data_batch Sequence

A batch of data from the dataloader (not used here).

required
data_samples List[Tuple[Tensor], Tuple[int], Tuple[str]]

A list containing three tuples:

- A tuple of deep_features: Each item is a Tensor of shape [T, 2944].
- A tuple of num_frames: Each item is an integer representing the number of valid frames.
- A tuple of video_name: Each item is a string representing the file name for the video.

The length of each of the three tuples equals the batch size.

required
Source code in aigve/metrics/video_quality_assessment/nn_based/gstvqa/gstvqa_metric.py
def process(self, data_batch: Sequence, data_samples: Sequence) -> None:
    """
    Process a batch of extracted deep features for GSTVQA evaluation and store results in a JSON file.

    Args:
        data_batch (Sequence): A batch of data from the dataloader (not used here).
        data_samples (List[Tuple[torch.Tensor], Tuple[int], Tuple[str]]):
            A list containing three tuples:
            - A tuple of `deep_features`: Each item is a Tensor of shape [T, 2944]. 
            - A tuple of `num_frames`: Each item is an integer representing the number of valid frames.
            - A tuple of `video_name`: Each item is a string representing the file name for the video.
            The length of each of the three tuples equals the batch size.
    """
    # data_samples an example: [
    #     (tensor([[0., 0., 0.,  ..., 0., 0., 0.],
    #              [0., 0., 0.,  ..., 0., 0., 0.],
    #              ...
    #              [0., 0., 0.,  ..., 0., 0., 0.]]), 
    #      tensor([[0., 0., 0.,  ..., 0., 0., 0.],
    #              [0., 0., 0.,  ..., 0., 0., 0.],
    #              ...
    #              [0., 0., 0.,  ..., 0., 0., 0.]])), 
    #     (10, 10)
    # ]
    results = []
    deep_features_tuple, num_frames_tuple, video_name_tuple = data_samples
    with torch.no_grad():
        for deep_features, num_valid_frames, video_name in zip(deep_features_tuple, num_frames_tuple, video_name_tuple):
            if not isinstance(deep_features, torch.Tensor) or not isinstance(num_valid_frames, int):
                raise TypeError("Expected deep_features to be a torch.Tensor and num_valid_frames to be an int.")

            if num_valid_frames == 0:  # Edge case: No valid frames
                results.append({"video_name": 'N/A', "GSTVQA_Score": 0.0})
                continue

            # Remove padded features
            features = deep_features[:num_valid_frames].to(self.device)

            # Compute statistical features only on valid frames
            mean_var, std_var, mean_mean, std_mean = self.compute_stat_features(features, num_valid_frames)
            mean_var, std_var, mean_mean, std_mean = (
                mean_var.to(self.device),
                std_var.to(self.device),
                mean_mean.to(self.device),
                std_mean.to(self.device),
            )

            # Length tensor indicating the number of valid frames
            length = torch.tensor([num_valid_frames]).to(self.device)
            # print('features(input) shape', features.unsqueeze(1).shape) # torch.Size([10, 1, 1472])
            # print('input_length shape', length.shape) # torch.Size([1])
            # print('input_length', length) # torch.Size([1])
            # print('mean_mean shape', mean_mean.shape) # torch.Size([1472])
            # print('std_mean shape', std_mean.shape) # torch.Size([1472])
            # print('mean_var shape', mean_var.shape) # torch.Size([1472])
            # print('std_var shape', std_var.shape) # torch.Size([1472])

            # Run GSTVQA model
            outputs = self.model(features.unsqueeze(1), length, mean_var, std_var, mean_mean, std_mean)
            score = outputs.item()
            results.append({"video_name": video_name, "GSTVQA_Score": score})
            # print(f"Processed score {score:.4f} for {video_name}")

    self.results.extend(results)

LightVQAPlus

Bases: BaseMetric

LightVQA+ metric for evaluating video quality.

Source code in aigve/metrics/video_quality_assessment/nn_based/lightvqa_plus/lightvqa_plus_metric.py
@METRICS.register_module()
class LightVQAPlus(BaseMetric):
    """LightVQA+ metric for evaluating video quality."""

    def __init__(self, model_path: str, swin_weights: str, is_gpu: bool = True):
        super(LightVQAPlus, self).__init__()
        self.model_path = model_path
        self.swin_weights = swin_weights
        self.device = torch.device("cuda" if is_gpu else "cpu")

        self.submodel_path = os.path.join(os.getcwd(), 'metrics/video_quality_assessment/nn_based/lightvqa_plus')
        if not submodule_exists(self.submodel_path):
            add_git_submodule(
                repo_url='https://github.com/SaMMyCHoo/Light-VQA-plus.git', 
                submodule_path=self.submodel_path
            )
        lightvqa_path = os.path.join(self.submodel_path, "Light_VQA_plus")
        if lightvqa_path not in sys.path:
            sys.path.insert(0, lightvqa_path)

        from .Light_VQA_plus.final_fusion_model import swin_small_patch4_window7_224 as create_model
        self.model = create_model().to(self.device)

        weights_dict = torch.load(os.path.join(os.getcwd(), self.model_path), map_location=self.device)
        print(self.model.load_state_dict(weights_dict))

        self.model.eval()

    def process(self, data_batch: list, data_samples: list) -> None:
        """
        Process a batch of extracted deep features for LightVQA+ evaluation.
        Args:
            data_batch (Sequence): A batch of data from the dataloader (not used here).
            data_samples (List[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[str]]):
                A list containing five tuples:
                - spatial_features (torch.Tensor): 8 evenly spaced key frames. Shape: [8, 3, 672, 1120].
                - temporal_features (torch.Tensor): Motion features from SlowFast. Shape: [1, feature_dim(2304)].
                - bns_features (torch.Tensor): Brightness & Noise features. Shape: [8, 300].
                - bc_features (torch.Tensor): Temporal brightness contrast features. Shape: [8, final_dim(20)].
                - video_name (str): Video filename.
                The length of each tuple equals the batch size.
        """
        results = []
        spatial_features_tuple, temporal_features_tuple, bns_features_tuple, bc_features_tuple, video_name_tuple = data_samples
        # print('spatial_features_tuple len: ', len(spatial_features_tuple)) # B
        # print('spatial_features_tuple[0]: ', spatial_features_tuple[0].shape) # torch.Size([8, 3, 672, 1120])
        # print('temporal_features_tuple[0]: ', temporal_features_tuple[0].shape) # torch.Size([1, 2304])
        # print('bns_features_tuple[0]: ', bns_features_tuple[0].shape) # torch.Size([8, 300])
        # print('bc_features_tuple[0]: ', bc_features_tuple[0].shape) # torch.Size([8, 20])

        batch_size = len(spatial_features_tuple)
        with torch.no_grad():
            for i in range(batch_size):
                video_name = video_name_tuple[i]
                spatial_features = spatial_features_tuple[i].to(self.device) # torch.Size([8, 3, 672, 1120])
                temporal_features = temporal_features_tuple[i].to(self.device) # torch.Size([1, 2304])
                bns_features = bns_features_tuple[i].to(self.device) # torch.Size([8, 300])
                bc_features = bc_features_tuple[i].to(self.device)  # Shape: [8, final_dim(20)]

                concat_features = torch.cat([temporal_features, bc_features.view(1, -1)], dim=1) # torch.Size([1, 2304+8*20])
                # print('concat_features: ', concat_features.shape) # torch.Size([1, 2464])
                final_temporal_features = F.pad(concat_features, (0, 2604 - concat_features.shape[1]), mode="constant", value=0) # torch.Size([1, 2604])
                # print('final_temporal_features: ', final_temporal_features.shape) # torch.Size([1, 2604])

                outputs = self.model(spatial_features, final_temporal_features, bns_features)
                # print('outputs: ', outputs)
                score = outputs.mean().item()

                results.append({"video_name": video_name, "LightVQAPlus_Score": score})
                print(f"Processed score {score:.4f} for {video_name}")

        self.results.extend(results)

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute final LightVQA+ metrics."""
        scores = np.array([res["LightVQAPlus_Score"] for res in self.results])
        mean_score = np.mean(scores) if scores.size > 0 else 0.0
        print(f"LightVQA+ mean score: {mean_score:.4f}")

        json_file_path = os.path.join(os.getcwd(), "lightvqaplus_results.json")
        final_results = {"video_results": self.results, "LightVQAPlus_Mean_Score": mean_score}
        with open(json_file_path, "w") as json_file:
            json.dump(final_results, json_file, indent=4)
        print(f"LightVQA+ mean score saved to {json_file_path}")

        return {"LightVQAPlus_Mean_Score": mean_score}

compute_metrics(results)

Compute final LightVQA+ metrics.

Source code in aigve/metrics/video_quality_assessment/nn_based/lightvqa_plus/lightvqa_plus_metric.py
def compute_metrics(self, results: list) -> Dict[str, float]:
    """Compute final LightVQA+ metrics."""
    scores = np.array([res["LightVQAPlus_Score"] for res in self.results])
    mean_score = np.mean(scores) if scores.size > 0 else 0.0
    print(f"LightVQA+ mean score: {mean_score:.4f}")

    json_file_path = os.path.join(os.getcwd(), "lightvqaplus_results.json")
    final_results = {"video_results": self.results, "LightVQAPlus_Mean_Score": mean_score}
    with open(json_file_path, "w") as json_file:
        json.dump(final_results, json_file, indent=4)
    print(f"LightVQA+ mean score saved to {json_file_path}")

    return {"LightVQAPlus_Mean_Score": mean_score}

process(data_batch, data_samples)

Process a batch of extracted deep features for LightVQA+ evaluation.

Parameters:

Name Type Description Default
data_batch Sequence

A batch of data from the dataloader (not used here).

required
data_samples List[Tuple[Tensor], Tuple[Tensor], Tuple[Tensor], Tuple[Tensor], Tuple[str]]

A list containing five tuples:

- spatial_features (torch.Tensor): 8 evenly spaced key frames. Shape: [8, 3, 672, 1120].
- temporal_features (torch.Tensor): Motion features from SlowFast. Shape: [1, feature_dim(2304)].
- bns_features (torch.Tensor): Brightness & Noise features. Shape: [8, 300].
- bc_features (torch.Tensor): Temporal brightness contrast features. Shape: [8, final_dim(20)].
- video_name (str): Video filename.

The length of each tuple equals the batch size.

required

Source code in aigve/metrics/video_quality_assessment/nn_based/lightvqa_plus/lightvqa_plus_metric.py
def process(self, data_batch: list, data_samples: list) -> None:
    """
    Process a batch of extracted deep features for LightVQA+ evaluation.
    Args:
        data_batch (Sequence): A batch of data from the dataloader (not used here).
        data_samples (List[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[str]]):
            A list containing five tuples:
            - spatial_features (torch.Tensor): 8 evenly spaced key frames. Shape: [8, 3, 672, 1120].
            - temporal_features (torch.Tensor): Motion features from SlowFast. Shape: [1, feature_dim(2304)].
            - bns_features (torch.Tensor): Brightness & Noise features. Shape: [8, 300].
            - bc_features (torch.Tensor): Temporal brightness contrast features. Shape: [8, final_dim(20)].
            - video_name (str): Video filename.
            The length of each tuple equals the batch size.
    """
    results = []
    spatial_features_tuple, temporal_features_tuple, bns_features_tuple, bc_features_tuple, video_name_tuple = data_samples
    # print('spatial_features_tuple len: ', len(spatial_features_tuple)) # B
    # print('spatial_features_tuple[0]: ', spatial_features_tuple[0].shape) # torch.Size([8, 3, 672, 1120])
    # print('temporal_features_tuple[0]: ', temporal_features_tuple[0].shape) # torch.Size([1, 2304])
    # print('bns_features_tuple[0]: ', bns_features_tuple[0].shape) # torch.Size([8, 300])
    # print('bc_features_tuple[0]: ', bc_features_tuple[0].shape) # torch.Size([8, 20])

    batch_size = len(spatial_features_tuple)
    with torch.no_grad():
        for i in range(batch_size):
            video_name = video_name_tuple[i]
            spatial_features = spatial_features_tuple[i].to(self.device) # torch.Size([8, 3, 672, 1120])
            temporal_features = temporal_features_tuple[i].to(self.device) # torch.Size([1, 2304])
            bns_features = bns_features_tuple[i].to(self.device) # torch.Size([8, 300])
            bc_features = bc_features_tuple[i].to(self.device)  # Shape: [8, final_dim(20)]

            concat_features = torch.cat([temporal_features, bc_features.view(1, -1)], dim=1) # torch.Size([1, 2304+8*20])
            # print('concat_features: ', concat_features.shape) # torch.Size([1, 2464])
            final_temporal_features = F.pad(concat_features, (0, 2604 - concat_features.shape[1]), mode="constant", value=0) # torch.Size([1, 2604])
            # print('final_temporal_features: ', final_temporal_features.shape) # torch.Size([1, 2604])

            outputs = self.model(spatial_features, final_temporal_features, bns_features)
            # print('outputs: ', outputs)
            score = outputs.mean().item()

            results.append({"video_name": video_name, "LightVQAPlus_Score": score})
            print(f"Processed score {score:.4f} for {video_name}")

    self.results.extend(results)
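
The feature fusion inside the loop is just concatenation plus zero-padding to a fixed width; a shape-only sketch with random tensors:

import torch
import torch.nn.functional as F

temporal_features = torch.rand(1, 2304)   # SlowFast features
bc_features = torch.rand(8, 20)           # brightness-consistency features

concat = torch.cat([temporal_features, bc_features.view(1, -1)], dim=1)  # [1, 2304 + 8*20] = [1, 2464]
final_temporal = F.pad(concat, (0, 2604 - concat.shape[1]), mode="constant", value=0)
print(concat.shape, final_temporal.shape)  # torch.Size([1, 2464]) torch.Size([1, 2604])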

LightVQAPlusDataset

Bases: Dataset

Dataset for LightVQA+. Extracts:

- spatial_features (torch.Tensor): Extracted key frames.
- temporal_features (torch.Tensor): SlowFast motion features.
- BNS_features (torch.Tensor): Brightness & Noise features.
- BC_features (torch.Tensor): Temporal CLIP-based brightness contrast features.
- video_name (str): Video filename.

Source code in aigve/datasets/lightvqa_plus_dataset.py
@DATASETS.register_module()
class LightVQAPlusDataset(Dataset):
    """
    Dataset for LightVQA+.
    Extracts:
        - spatial_features (torch.Tensor): Extracted key frames.
        - temporal_features (torch.Tensor): SlowFast motion features.
        - BNS_features (torch.Tensor): Brightness & Noise features.
        - BC_features (torch.Tensor): Temporal CLIP-based brightness contrast features.
        - video_name (str): Video filename.
    """

    def __init__(self, video_dir, prompt_dir, min_video_seconds=8):
        super(LightVQAPlusDataset, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.video_dir = video_dir
        self.prompt_dir = prompt_dir
        self.min_video_seconds = min_video_seconds

        self.video_names = self._read_video_names()

        # Load CLIP model for BNS and BC features
        self.clip_model, _ = clip.load("ViT-B/32", device="cpu")
        self.preprocess = transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
        )
        self.to_tensor = transforms.ToTensor()

        # CLIP text prompts
        self.text_B = clip.tokenize([  # brightness (B)
            "an underexposed photo", "a slightly underexposed photo",
            "a well-exposed photo", "a slightly overexposed photo", "an overexposed photo"
        ])

        self.text_N = clip.tokenize([ # noise (N)
            "a photo with no noise", "a photo with little noise",
            "a photo with considerable noise", "a photo with serious noise", "a photo with extreme noise"
        ])

        self.submodel_path = os.path.join(os.getcwd(), 'metrics/video_quality_assessment/nn_based/lightvqa_plus')
        if not submodule_exists(self.submodel_path):
            add_git_submodule(
                repo_url='https://github.com/SaMMyCHoo/Light-VQA-plus.git', 
                submodule_path=self.submodel_path
            )
        # original_path = os.path.join(self.submodel_path, "Light-VQA-plus")
        lightvqa_path = os.path.join(self.submodel_path, "Light_VQA_plus")
        # if os.path.exists(original_path) and not os.path.exists(lightvqa_path):
        #     os.rename(original_path, lightvqa_path)
        if lightvqa_path not in sys.path:
            sys.path.insert(0, lightvqa_path)
        # print(sys.path)

        # Load SlowFast model
        slowfast, _ = lazy_import()
        self.slowfast_model = slowfast()

    def _read_video_names(self):
        """Reads video names from the dataset JSON file."""
        with open(self.prompt_dir, 'r') as reader:
            read_data = json.load(reader)
        return [item['video_path_pd'].strip() for item in read_data["data_list"]]

    def __len__(self):
        return len(self.video_names)

    def extract_key_frames(self, video_path):
        """
        Extracts 8 evenly spaced key frames across the entire video duration.

        Args:
            video_path (str): Path to the video file.

        Returns:
            spatial_features (torch.Tensor): Shape [8, 3, 672, 1120] containing 8 key frames.
        """
        cap = cv2.VideoCapture(video_path)
        video_name = os.path.basename(video_path).split('.')[0]

        video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if video_length >= 8:
            # Select 8 unique frame indices evenly spaced across the entire video
            frame_indices = np.round(np.linspace(0, video_length - 1, num=8)).astype(int)
        else:
            # Select all available frames and repeat the last one to reach 8
            frame_indices = list(range(video_length)) + [video_length - 1] * (8 - video_length)

        spatial_features = torch.zeros([8, 3, 672, 1120])  # Ensure exactly 8 frames
        transform = transforms.Compose([
            transforms.Resize([672, 1120]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        last_valid_frame = None
        for idx, frame_idx in enumerate(frame_indices):
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if ret:
                frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                spatial_features[idx] = transform(frame)
                last_valid_frame = spatial_features[idx]
            elif last_valid_frame is not None:  # If total frames are less than 8, repeat the last valid frame
                spatial_features[idx] = last_valid_frame

        cap.release()
        # print('spatial_features: ', spatial_features.shape) # torch.Size([8, 3, 672, 1120])
        return spatial_features

    def get_global_sf(self, video_path) -> torch.Tensor:
        """Extracts global brightness & noise features across full video.

        Args:
            video_path (str): Path to video file.

        Returns:
            torch.Tensor: Extracted global features (Shape: [8, 150]).
        """
        cap = cv2.VideoCapture(video_path)
        video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # print('video_length: ', video_length)  # 16

        frames = []
        for _ in range(video_length):
            ret, frame = cap.read()
            if ret:
                frame = cv2.resize(frame, (1120, 672))
                frames.append(frame)
        cap.release()

        if not frames:
            raise ValueError(f"Failed to extract frames from {video_path}")

        res = []
        length = len(frames)
        now = 0
        interval = 10  # Process 10 frames at a time
        while now + interval - 1 < length:
            final = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i + now], cv2.COLOR_BGR2RGB)))
                    for i in range(interval)]

            # Step 1: Convert to tensor batch
            images = torch.stack(final, dim=0)  # Shape: [10, 3, 672, 1120]

            # Step 2: Unfold into patches (Strictly following GET_SF)
            images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [10, 3, 3, 5, 224, 224]
            images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [10, 5, 3, 3, 224, 224]
            images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [10, 15, 3, 224, 224]
            images = images.view(-1, 3, 224, 224)  # Shape: [150, 3, 224, 224]
            images = self.preprocess(images)  # Normalize for CLIP
            # print('images get_global_sf: ', images.shape) # torch.Size([10*15, 3, 224, 224])

            # Step 3: Extract features using CLIP
            with torch.no_grad():
                logits_N, _ = self.clip_model(images, self.text_N)
                logits_B, _ = self.clip_model(images, self.text_B)

            tmp_N = logits_N.softmax(dim=-1).view(interval, -1) * 10
            tmp_B = logits_B.softmax(dim=-1).view(interval, -1) * 10
            # print('tmp_N get_global_sf', tmp_N.shape) # torch.Size([10, 75])
            # print('tmp_B get_global_sf', tmp_B.shape) # torch.Size([10, 75])
            res.append(torch.cat([tmp_N, tmp_B], dim=1))
            now += interval

        # Handle remaining frames
        if length > now:
            final = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)))
                    for i in range(now, length)]

            images = torch.stack(final, dim=0)  # Shape: [remaining(6), 3, 672, 1120]
            images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [remaining, 3, 3, 5, 224, 224]
            images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [remaining, 5, 3, 3, 224, 224]
            images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [remaining, 15, 3, 224, 224]
            images = images.view(-1, 3, 224, 224)  # Shape: [remaining*15, 3, 224, 224]
            images = self.preprocess(images)

            with torch.no_grad():
                logits_N, _ = self.clip_model(images, self.text_N) # Shape: [remaining, 5(num_text_prompts)]
                logits_B, _ = self.clip_model(images, self.text_B) # Shape: [remaining, 5]
                # print('logits_N last get_global_sf', logits_N.shape) # torch.Size([6*15, 5])
                # print('logits_B last get_global_sf', logits_B.shape) #torch.Size([6*15, 5])

            tmp_N = logits_N.softmax(dim=-1).view(length - now, -1) * 10 # Shape: [remaining, 75]
            tmp_B = logits_B.softmax(dim=-1).view(length - now, -1) * 10 # Shape: [remaining, 75]
            # print('tmp_N last get_global_sf', tmp_N.shape)  # torch.Size([6, 75])
            # print('tmp_B last get_global_sf', tmp_B.shape)  # torch.Size([6, 75])

            res.append(torch.cat([tmp_N, tmp_B], dim=1))

        res = torch.cat(res, dim=0)  # Shape: [length, 150]
        # print('res: ', res.shape)  # torch.Size([16, 150]) for toy dataset

        # Step 4: Aggregate into 8 time slots
        chunk_size = length // 8
        final_res = [
            torch.mean(res[i * chunk_size: (i + 1) * chunk_size], dim=0) if i < 7 else torch.mean(res[7 * chunk_size:], dim=0)
            for i in range(8)
        ]

        return torch.stack(final_res, dim=0)  # Shape: [8, 150]

    def extract_bns_features(self, video_path):
        """Extracts Brightness & Noise Sensitivity (BNS) features using CLIP.
        Local Feature Extraction (res1) → Uses 8 key frames
        Global Feature Extraction (res2) → Uses all frames

        Args:
            video_path (str): Path to the video file.

        Returns:
            spatial_features (torch.Tensor): Extracted 8 evenly spaced key frames across the entire video duration.
                Shape [8, 3, 672, 1120] containing 8 key frames.
            final_res (torch.Tensor): Extracted BNS feature (Shape: [8, 300]).
        """
        # Local Feature Extraction Step 1: Extract key frames
        spatial_features = self.extract_key_frames(video_path) # Shape: [8, 3, 672, 1120]

        # Step 2: Apply unfolding transformation (Strictly following GET_S_F)
        images = spatial_features.unfold(2, 224, 224).unfold(3, 224, 224)  # Break into patches. Shape: [8, 3, 3, 5, 224, 224]
        images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [8, 5, 3, 3, 224, 224]
        images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [8, 15, 3, 224, 224]
        images = images.view(-1, 3, 224, 224)  # Shape: [120, 3, 224, 224]
        images = self.preprocess(images)  # Normalize for CLIP
        # print('images: ', images.shape) # torch.Size([120, 3, 224, 224])
        # print(images.device)
        # print(self.text_N.device)

        # Step 3: Pass through CLIP
        with torch.no_grad():
            logits_N, _ = self.clip_model(images, self.text_N)
            logits_B, _ = self.clip_model(images, self.text_B)

        res_N = logits_N.softmax(dim=-1).view(8, -1) * 10
        # print('res_N: ', res_N.shape) # torch.Size([8, 75])
        res_B = logits_B.softmax(dim=-1).view(8, -1) * 10
        # print('res_B: ', res_N.shape) # torch.Size([8, 75])
        res1 = torch.cat((res_N, res_B), dim=1)
        # print('res1: ', res1.shape) # torch.Size([8, 150])

        # Global Feature Extraction (GET_SF Equivalent)
        res2 = self.get_global_sf(video_path)
        # print('res2: ', res2.shape) # res2:  torch.Size([8, 150])

        # Split & Combine Features
        Nl, Bl = torch.split(res1, 75, dim=1)
        Ng, Bg = torch.split(res2, 75, dim=1)
        final_res = torch.cat([Nl, Ng, Bl, Bg], dim=1)
        # print('final_res: ', final_res.shape)

        return spatial_features, final_res  # Shape: [8, 300]

    def extract_bc_features(self, video_path) -> torch.Tensor:
        """
        Extracts Brightness Consistency features using CLIP-based temporal processing.

        Returns:
            torch.Tensor: Extracted BC feature (Shape: [8, final_dim]).
        """

        cap = cv2.VideoCapture(video_path)
        video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        frames = []
        for _ in range(video_length):
            ret, frame = cap.read()
            if ret:
                frame = cv2.resize(frame, (1120, 672))
                frames.append(frame)
        cap.release()

        if not frames:
            raise ValueError(f"Failed to extract frames from {video_path}")

        res = []
        now = 0
        interval = 10  # Process 10 frames at a time
        length = len(frames)

        # Step 1: Extract CLIP Features at Fixed Intervals
        while now + interval - 1 < length:
            batch = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i + now], cv2.COLOR_BGR2RGB)))
                    for i in range(interval)]
            images = torch.stack(batch, dim=0)
            images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [10, 3, 3, 5, 224, 224]
            images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [10, 5, 3, 3, 224, 224]
            images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [10, 15, 3, 224, 224]
            images = images.view(-1, 3, 224, 224)  # Shape: [10*15, 3, 224, 224]
            images = self.preprocess(images)
            # print('images extract_bc_features', images.shape) # torch.Size([150, 3, 224, 224])

            with torch.no_grad():
                logits, _ = self.clip_model(images, self.text_B)

            tmp = logits.softmax(dim=-1) * 10
            res.append(tmp)
            now += interval

        # Handle Remaining Frames
        if length > now:
            batch = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)))
                    for i in range(now, length)]
            images = torch.stack(batch, dim=0)
            images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [remaining(6), 3, 3, 5, 224, 224]
            images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [remaining, 5, 3, 3, 224, 224]
            images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [remaining, 15, 3, 224, 224]
            images = images.view(-1, 3, 224, 224)  # Shape: [remaining, 15, 3, 224, 224]
            images = self.preprocess(images)
            # print('images: ', images.shape) #  torch.Size([6*15, 3, 224, 224])

            with torch.no_grad():
                logits, _ = self.clip_model(images, self.text_B)

            tmp = logits.softmax(dim=-1) * 10
            res.append(tmp)

        res = torch.cat(res, dim=0)  # Shape: [length, 5]
        # print('res extract_bc_features: ', res.shape) # torch.Size([150+90, 5])

        # Step 2: Multi-Scale Variance Computation: downsample frames steps
        # smaller step: Captures fast, fine-grained changes.
        # larger step:  Captures slow, long-term trends.
        final_res = []
        for step in [1, 2, 4, 8]:  # Multi-scale temporal steps
            chunk_number = 8 // step
            chunk_size = length // chunk_number
            chunks = []
            for i in range(chunk_number):
                if i < chunk_number - 1:
                    chunk = res[i * chunk_size : (i + 1) * chunk_size, :]
                else:
                    chunk = res[(chunk_number - 1) * chunk_size:, :]  # Handle remaining frames
                tmp = []
                for j in range(step):
                    temp = chunk[j::step, :]  
                    tmp.append(torch.var(temp.float(), dim=0))  # Variance computation
                chunks.append(tmp)  # final chunks len: 8; 4; 2; 1 
            final_res.append(chunks) # final final_res len: 4

        # Step 3: Aggregate Multi-Scale Features
        temp = []
        for i in range(8):  # Aggregate temporal information across 8 time slots
            temp.append(torch.cat(final_res[0][i]                                                # variance for step size = 1
                                + [torch.mean(torch.stack(final_res[1][i // 2], dim=0), dim=0)]  # for step size = 2
                                + [torch.mean(torch.stack(final_res[2][i // 4], dim=0), dim=0)]  # Every 4 slots share the same value.
                                + [torch.mean(torch.stack(final_res[3][i // 8], dim=0), dim=0)]  # for step size = 8
                                , dim=0))

        final_res = torch.stack(temp, dim=0)  # Shape: [8, final_dim]  
        # print('final_res extract_bc_featuresx: ', final_res.shape) # torch.Size([8, 20])

        return final_res

    def extract_temporal_features(self, video_path) -> torch.Tensor:
        """Extracts SlowFast motion features on the entire video segment.

        Args:
            video_path (str): Path to the video file.

        Returns:
            torch.Tensor: Extracted motion features (Shape: [1, feature_dim(2304)]).
        """
        cap = cv2.VideoCapture(video_path)
        video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_indices = np.round(np.linspace(0, video_length - 1, num=8)).astype(int)

        transform = transforms.Compose([
            transforms.Resize([224, 224]),  # Match SlowFast input size
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])  # Original normalization
        ])

        frames = []
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                frames.append(transform(frame))  # Resize & normalize
        cap.release()

        if len(frames) < 8:
            raise ValueError(f"Insufficient frames in {video_path}, expected 8.")

        video_tensor = torch.stack(frames, dim=0)  # Shape: [8, 3, 224, 224]

        # Prepare for SlowFast input
        video_tensor = video_tensor.unsqueeze(0)  # Add batch dimension: [1, 8, 3, 224, 224]
        video_tensor = video_tensor.permute(0, 2, 1, 3, 4)  # Shape: [1, 3, 8, 224, 224]

        # Pack pathways for SlowFast model
        _, pack_pathway_output = lazy_import()
        inputs = pack_pathway_output(video_tensor, device='cpu')
        # print('inputs len: ', len(inputs))
        # print('inputs[0]: ', inputs[0].shape) # torch.Size([1, 3, 2, 224, 224])
        # print('inputs[1]: ', inputs[1].shape) # torch.Size([1, 3, 8, 224, 224])

        # Extract features using SlowFast
        with torch.no_grad():
            slow_feature, fast_feature = self.slowfast_model(inputs)

        # print('slow_feature extract_temporal_features: ', slow_feature.shape) # torch.Size([1, 2048, 1, 1, 1])
        # print('fast_feature extract_temporal_features: ', fast_feature.shape) # torch.Size([1, 256, 1, 1, 1])

        # Concatenate slow and fast features
        features = torch.cat([slow_feature, fast_feature], dim=1).squeeze(-1).squeeze(-1).squeeze(-1)
        # print('features extract_temporal_features: ', features.shape) # torch.Size([1, 2304])

        return features

    def __getitem__(self, index):
        """
        Returns:
            spatial_features (torch.Tensor): Spatial features. Shape: [8, 3, 672, 1120].
            temporal_features (torch.Tensor): SlowFast motion features. Shape: [1, feature_dim(2304)].
            bns_features (torch.Tensor): Brightness & Noise features. Shape: [8, 300].
            bc_features (torch.Tensor): Temporal brightness contrast features. Shape: [8, final_dim(20)].
            video_name (str): Video filename.
        """
        video_name = self.video_names[index]
        video_path = os.path.join(self.video_dir, video_name)

        spatial_features, bns_features = self.extract_bns_features(video_path)
        bc_features = self.extract_bc_features(video_path)
        temporal_features = self.extract_temporal_features(video_path)

        return spatial_features, temporal_features, bns_features, bc_features, video_name
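
The key-frame selection used throughout the dataset reduces to a single np.linspace call; a standalone sketch of the index logic for both long and short videos:

import numpy as np

def key_frame_indices(video_length: int, num: int = 8):
    # Mirrors extract_key_frames: evenly spaced indices, or repeat the last
    # frame index when the video has fewer than `num` frames.
    if video_length >= num:
        return np.round(np.linspace(0, video_length - 1, num=num)).astype(int).tolist()
    return list(range(video_length)) + [video_length - 1] * (num - video_length)

print(key_frame_indices(16))   # [0, 2, 4, 6, 9, 11, 13, 15]
print(key_frame_indices(5))    # [0, 1, 2, 3, 4, 4, 4, 4]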

__getitem__(index)

Returns:

Name Type Description
spatial_features Tensor

Spatial features. Shape: [8, 3, 672, 1120].

temporal_features Tensor

SlowFast motion features. Shape: [1, feature_dim(2304)].

bns_features Tensor

Brightness & Noise features. Shape: [8, 300].

bc_features Tensor

Temporal brightness contrast features. Shape: [8, final_dim(20)].

video_name str

Video filename.

Source code in aigve/datasets/lightvqa_plus_dataset.py
def __getitem__(self, index):
    """
    Returns:
        spatial_features (torch.Tensor): Spatial features. Shape: [8, 3, 672, 1120].
        temporal_features (torch.Tensor): SlowFast motion features. Shape: [1, feature_dim(2304)].
        bns_features (torch.Tensor): Brightness & Noise features. Shape: [8, 300].
        bc_features (torch.Tensor): Temporal brightness contrast features. Shape: [8, final_dim(20)].
        video_name (str): Video filename.
    """
    video_name = self.video_names[index]
    video_path = os.path.join(self.video_dir, video_name)

    spatial_features, bns_features = self.extract_bns_features(video_path)
    bc_features = self.extract_bc_features(video_path)
    temporal_features = self.extract_temporal_features(video_path)

    return spatial_features, temporal_features, bns_features, bc_features, video_name

extract_bc_features(video_path)

Extracts Brightness Consistency features using CLIP-based temporal processing.

Returns:

Type Description
Tensor

torch.Tensor: Extracted BC feature (Shape: [8, final_dim]).

Source code in aigve/datasets/lightvqa_plus_dataset.py
def extract_bc_features(self, video_path) -> torch.Tensor:
    """
    Extracts Brightness Consistency features using CLIP-based temporal processing.

    Returns:
        torch.Tensor: Extracted BC feature (Shape: [8, final_dim]).
    """

    cap = cv2.VideoCapture(video_path)
    video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    frames = []
    for _ in range(video_length):
        ret, frame = cap.read()
        if ret:
            frame = cv2.resize(frame, (1120, 672))
            frames.append(frame)
    cap.release()

    if not frames:
        raise ValueError(f"Failed to extract frames from {video_path}")

    res = []
    now = 0
    interval = 10  # Process 10 frames at a time
    length = len(frames)

    # Step 1: Extract CLIP Features at Fixed Intervals
    while now + interval - 1 < length:
        batch = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i + now], cv2.COLOR_BGR2RGB)))
                for i in range(interval)]
        images = torch.stack(batch, dim=0)
        images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [10, 3, 3, 5, 224, 224]
        images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [10, 5, 3, 3, 224, 224]
        images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [10, 15, 3, 224, 224]
        images = images.view(-1, 3, 224, 224)  # Shape: [10*15, 3, 224, 224]
        images = self.preprocess(images)
        # print('images extract_bc_features', images.shape) # torch.Size([150, 3, 224, 224])

        with torch.no_grad():
            logits, _ = self.clip_model(images, self.text_B)

        tmp = logits.softmax(dim=-1) * 10
        res.append(tmp)
        now += interval

    # Handle Remaining Frames
    if length > now:
        batch = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)))
                for i in range(now, length)]
        images = torch.stack(batch, dim=0)
        images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [remaining(6), 3, 3, 5, 224, 224]
        images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [remaining, 5, 3, 3, 224, 224]
        images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [remaining, 15, 3, 224, 224]
        images = images.view(-1, 3, 224, 224)  # Shape: [remaining, 15, 3, 224, 224]
        images = self.preprocess(images)
        # print('images: ', images.shape) #  torch.Size([6*15, 3, 224, 224])

        with torch.no_grad():
            logits, _ = self.clip_model(images, self.text_B)

        tmp = logits.softmax(dim=-1) * 10
        res.append(tmp)

    res = torch.cat(res, dim=0)  # Shape: [length, 5]
    # print('res extract_bc_features: ', res.shape) # torch.Size([150+90, 5])

    # Step 2: Multi-Scale Variance Computation: downsample frames steps
    # smaller step: Captures fast, fine-grained changes.
    # larger step:  Captures slow, long-term trends.
    final_res = []
    for step in [1, 2, 4, 8]:  # Multi-scale temporal steps
        chunk_number = 8 // step
        chunk_size = length // chunk_number
        chunks = []
        for i in range(chunk_number):
            if i < chunk_number - 1:
                chunk = res[i * chunk_size : (i + 1) * chunk_size, :]
            else:
                chunk = res[(chunk_number - 1) * chunk_size:, :]  # Handle remaining frames
            tmp = []
            for j in range(step):
                temp = chunk[j::step, :]  
                tmp.append(torch.var(temp.float(), dim=0))  # Variance computation
            chunks.append(tmp)  # final chunks len: 8; 4; 2; 1 
        final_res.append(chunks) # final final_res len: 4

    # Step 3: Aggregate Multi-Scale Features
    temp = []
    for i in range(8):  # Aggregate temporal information across 8 time slots
        temp.append(torch.cat(final_res[0][i]                                                # variance for step size = 1
                            + [torch.mean(torch.stack(final_res[1][i // 2], dim=0), dim=0)]  # for step size = 2
                            + [torch.mean(torch.stack(final_res[2][i // 4], dim=0), dim=0)]  # Every 4 slots share the same value.
                            + [torch.mean(torch.stack(final_res[3][i // 8], dim=0), dim=0)]  # for step size = 8
                            , dim=0))

    final_res = torch.stack(temp, dim=0)  # Shape: [8, final_dim]  
    # print('final_res extract_bc_featuresx: ', final_res.shape) # torch.Size([8, 20])

    return final_res
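
To make Step 2 concrete outside the dataset code: for each temporal step the per-frame scores are chunked, and within each chunk the variance over every step-th frame is taken, so small steps capture fast flicker and large steps capture slow drift. A standalone sketch with dummy scores:

import torch

scores = torch.rand(240, 5)            # stand-in for the per-frame CLIP scores (`res` above)
length = scores.shape[0]

multi_scale = []
for step in [1, 2, 4, 8]:              # same temporal steps as the dataset code
    chunk_number = 8 // step
    chunk_size = length // chunk_number
    per_chunk = []
    for i in range(chunk_number):
        start = i * chunk_size
        end = (i + 1) * chunk_size if i < chunk_number - 1 else length
        chunk = scores[start:end]
        # Variance over every `step`-th frame inside the chunk.
        per_chunk.append([torch.var(chunk[j::step], dim=0) for j in range(step)])
    multi_scale.append(per_chunk)

print(len(multi_scale[0]), len(multi_scale[3]))  # 8 chunks at step=1, 1 chunk at step=8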

extract_bns_features(video_path)

Extracts Brightness & Noise Sensitivity (BNS) features using CLIP. Local feature extraction (res1) uses 8 key frames; global feature extraction (res2) uses all frames.

Parameters:

Name Type Description Default
video_path str

Path to the video file.

required

Returns:

Name Type Description
spatial_features Tensor

Extracted 8 evenly spaced key frames across the entire video duration. Shape [8, 3, 672, 1120] containing 8 key frames.

final_res Tensor

Extracted BNS feature (Shape: [8, 300]).

Source code in aigve/datasets/lightvqa_plus_dataset.py
def extract_bns_features(self, video_path):
    """Extracts Brightness & Noise Sensitivity (BNS) features using CLIP.
    Local Feature Extraction (res1) → Uses 8 key frames
    Global Feature Extraction (res2) → Uses all frames

    Args:
        video_path (str): Path to the video file.

    Returns:
        spatial_features (torch.Tensor): Extracted 8 evenly spaced key frames across the entire video duration.
            Shape [8, 3, 672, 1120] containing 8 key frames.
        final_res (torch.Tensor): Extracted BNS feature (Shape: [8, 300]).
    """
    # Local Feature Extraction Step 1: Extract key frames
    spatial_features = self.extract_key_frames(video_path) # Shape: [8, 3, 672, 1120]

    # Step 2: Apply unfolding transformation (Strictly following GET_S_F)
    images = spatial_features.unfold(2, 224, 224).unfold(3, 224, 224)  # Break into patches. Shape: [8, 3, 3, 5, 224, 224]
    images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [8, 5, 3, 3, 224, 224]
    images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [8, 15, 3, 224, 224]
    images = images.view(-1, 3, 224, 224)  # Shape: [120, 3, 224, 224]
    images = self.preprocess(images)  # Normalize for CLIP
    # print('images: ', images.shape) # torch.Size([120, 3, 224, 224])
    # print(images.device)
    # print(self.text_N.device)

    # Step 3: Pass through CLIP
    with torch.no_grad():
        logits_N, _ = self.clip_model(images, self.text_N)
        logits_B, _ = self.clip_model(images, self.text_B)

    res_N = logits_N.softmax(dim=-1).view(8, -1) * 10
    # print('res_N: ', res_N.shape) # torch.Size([8, 75])
    res_B = logits_B.softmax(dim=-1).view(8, -1) * 10
    # print('res_B: ', res_N.shape) # torch.Size([8, 75])
    res1 = torch.cat((res_N, res_B), dim=1)
    # print('res1: ', res1.shape) # torch.Size([8, 150])

    # Global Feature Extraction (GET_SF Equivalent)
    res2 = self.get_global_sf(video_path)
    # print('res2: ', res2.shape) # res2:  torch.Size([8, 150])

    # Split & Combine Features
    Nl, Bl = torch.split(res1, 75, dim=1)
    Ng, Bg = torch.split(res2, 75, dim=1)
    final_res = torch.cat([Nl, Ng, Bl, Bg], dim=1)
    # print('final_res: ', final_res.shape)

    return spatial_features, final_res  # Shape: [8, 300]
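
The unfolding used above splits each 672x1120 frame into a 3x5 grid of non-overlapping 224x224 crops, i.e. 15 CLIP-sized patches per frame. A minimal sketch of just that transformation on a dummy tensor:

import torch

frames = torch.rand(8, 3, 672, 1120)                       # 8 key frames
patches = frames.unfold(2, 224, 224).unfold(3, 224, 224)   # [8, 3, 3, 5, 224, 224]
patches = patches.permute(0, 3, 2, 1, 4, 5).contiguous()   # [8, 5, 3, 3, 224, 224]
patches = patches.reshape(-1, 15, 3, 224, 224)             # 15 patches per frame
patches = patches.view(-1, 3, 224, 224)                    # [120, 3, 224, 224], ready for CLIP
print(patches.shape)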

extract_key_frames(video_path)

Extracts 8 evenly spaced key frames across the entire video duration.

Parameters:

Name Type Description Default
video_path str

Path to the video file.

required

Returns:

Name Type Description
spatial_features Tensor

Shape [8, 3, 672, 1120] containing 8 key frames.

Source code in aigve/datasets/lightvqa_plus_dataset.py
def extract_key_frames(self, video_path):
    """
    Extracts 8 evenly spaced key frames across the entire video duration.

    Args:
        video_path (str): Path to the video file.

    Returns:
        spatial_features (torch.Tensor): Shape [8, 3, 672, 1120] containing 8 key frames.
    """
    cap = cv2.VideoCapture(video_path)
    video_name = os.path.basename(video_path).split('.')[0]

    video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    if video_length >= 8:
        # Select 8 unique frame indices evenly spaced across the entire video
        frame_indices = np.round(np.linspace(0, video_length - 1, num=8)).astype(int)
    else:
        # Select all available frames and repeat the last one to reach 8
        frame_indices = list(range(video_length)) + [video_length - 1] * (8 - video_length)

    spatial_features = torch.zeros([8, 3, 672, 1120])  # Ensure exactly 8 frames
    transform = transforms.Compose([
        transforms.Resize([672, 1120]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    last_valid_frame = None
    for idx, frame_idx in enumerate(frame_indices):
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        if ret:
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            spatial_features[idx] = transform(frame)
            last_valid_frame = spatial_features[idx]
        elif last_valid_frame is not None:  # If total frames are less than 8, repeat the last valid frame
            spatial_features[idx] = last_valid_frame

    cap.release()
    # print('spatial_features: ', spatial_features.shape) # torch.Size([8, 3, 672, 1120])
    return spatial_features
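
The index-selection branch above guarantees exactly 8 indices for any video length; a tiny standalone sketch of the same logic:

import numpy as np

def eight_indices(video_length: int) -> list:
    # Evenly spaced indices when the video is long enough,
    # otherwise pad by repeating the last frame index (mirrors the branches above).
    if video_length >= 8:
        return np.round(np.linspace(0, video_length - 1, num=8)).astype(int).tolist()
    return list(range(video_length)) + [video_length - 1] * (8 - video_length)

print(eight_indices(100))  # [0, 14, 28, 42, 57, 71, 85, 99]
print(eight_indices(5))    # [0, 1, 2, 3, 4, 4, 4, 4]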

extract_temporal_features(video_path)

Extracts SlowFast motion features on the entire video segment.

Parameters:

Name Type Description Default
video_path str

Path to the video file.

required

Returns:

Type Description
Tensor

torch.Tensor: Extracted motion features (Shape: [1, feature_dim(2304)]).

Source code in aigve/datasets/lightvqa_plus_dataset.py
def extract_temporal_features(self, video_path) -> torch.Tensor:
    """Extracts SlowFast motion features on the entire video segment.

    Args:
        video_path (str): Path to the video file.

    Returns:
        torch.Tensor: Extracted motion features (Shape: [1, feature_dim(2304)]).
    """
    cap = cv2.VideoCapture(video_path)
    video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_indices = np.round(np.linspace(0, video_length - 1, num=8)).astype(int)

    transform = transforms.Compose([
        transforms.Resize([224, 224]),  # Match SlowFast input size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])  # Original normalization
    ])

    frames = []
    for idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frames.append(transform(frame))  # Resize & normalize
    cap.release()

    if len(frames) < 8:
        raise ValueError(f"Insufficient frames in {video_path}, expected 8.")

    video_tensor = torch.stack(frames, dim=0)  # Shape: [8, 3, 224, 224]

    # Prepare for SlowFast input
    video_tensor = video_tensor.unsqueeze(0)  # Add batch dimension: [1, 8, 3, 224, 224]
    video_tensor = video_tensor.permute(0, 2, 1, 3, 4)  # Shape: [1, 3, 8, 224, 224]

    # Pack pathways for SlowFast model
    _, pack_pathway_output = lazy_import()
    inputs = pack_pathway_output(video_tensor, device='cpu')
    # print('inputs len: ', len(inputs))
    # print('inputs[0]: ', inputs[0].shape) # torch.Size([1, 3, 2, 224, 224])
    # print('inputs[1]: ', inputs[1].shape) # torch.Size([1, 3, 8, 224, 224])

    # Extract features using SlowFast
    with torch.no_grad():
        slow_feature, fast_feature = self.slowfast_model(inputs)

    # print('slow_feature extract_temporal_features: ', slow_feature.shape) # torch.Size([1, 2048, 1, 1, 1])
    # print('fast_feature extract_temporal_features: ', fast_feature.shape) # torch.Size([1, 256, 1, 1, 1])

    # Concatenate slow and fast features
    features = torch.cat([slow_feature, fast_feature], dim=1).squeeze(-1).squeeze(-1).squeeze(-1)
    # print('features extract_temporal_features: ', features.shape) # torch.Size([1, 2304])

    return features
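
pack_pathway_output is imported lazily from the SlowFast codebase, so it is not shown here. Conceptually it keeps the full 8-frame sequence for the fast pathway and temporally subsamples it (by a factor alpha, 4 in the printed shapes above) for the slow pathway. The function below is only a toy re-implementation of that idea, not the library routine:

import torch

def pack_pathway_sketch(frames: torch.Tensor, alpha: int = 4):
    """Toy SlowFast pathway packing; `frames` is [B, C, T, H, W]."""
    fast = frames
    idx = torch.linspace(0, frames.shape[2] - 1, frames.shape[2] // alpha).long()
    slow = torch.index_select(frames, 2, idx)   # subsample the temporal dimension
    return [slow, fast]

video = torch.rand(1, 3, 8, 224, 224)
slow, fast = pack_pathway_sketch(video)
print(slow.shape, fast.shape)  # torch.Size([1, 3, 2, 224, 224]) torch.Size([1, 3, 8, 224, 224])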

get_global_sf(video_path)

Extracts global brightness & noise features across the full video.

Parameters:

Name Type Description Default
video_path str

Path to video file.

required

Returns:

Type Description
Tensor

torch.Tensor: Extracted global features (Shape: [8, 150]).

Source code in aigve/datasets/lightvqa_plus_dataset.py
def get_global_sf(self, video_path) -> torch.Tensor:
    """Extracts global brightness & noise features across full video.

    Args:
        video_path (str): Path to video file.

    Returns:
        torch.Tensor: Extracted global features (Shape: [8, 150]).
    """
    cap = cv2.VideoCapture(video_path)
    video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # print('video_length: ', video_length)  # 16

    frames = []
    for _ in range(video_length):
        ret, frame = cap.read()
        if ret:
            frame = cv2.resize(frame, (1120, 672))
            frames.append(frame)
    cap.release()

    if not frames:
        raise ValueError(f"Failed to extract frames from {video_path}")

    res = []
    length = len(frames)
    now = 0
    interval = 10  # Process 10 frames at a time
    while now + interval - 1 < length:
        final = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i + now], cv2.COLOR_BGR2RGB)))
                for i in range(interval)]

        # Step 1: Convert to tensor batch
        images = torch.stack(final, dim=0)  # Shape: [10, 3, 672, 1120]

        # Step 2: Unfold into patches (Strictly following GET_SF)
        images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [10, 3, 3, 5, 224, 224]
        images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [10, 5, 3, 3, 224, 224]
        images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [10, 15, 3, 224, 224]
        images = images.view(-1, 3, 224, 224)  # Shape: [150, 3, 224, 224]
        images = self.preprocess(images)  # Normalize for CLIP
        # print('images get_global_sf: ', images.shape) # torch.Size([10*15, 3, 224, 224])

        # Step 3: Extract features using CLIP
        with torch.no_grad():
            logits_N, _ = self.clip_model(images, self.text_N)
            logits_B, _ = self.clip_model(images, self.text_B)

        tmp_N = logits_N.softmax(dim=-1).view(interval, -1) * 10
        tmp_B = logits_B.softmax(dim=-1).view(interval, -1) * 10
        # print('tmp_N get_global_sf', tmp_N.shape) # torch.Size([10, 75])
        # print('tmp_B get_global_sf', tmp_B.shape) # torch.Size([10, 75])
        res.append(torch.cat([tmp_N, tmp_B], dim=1))
        now += interval

    # Handle remaining frames
    if length > now:
        final = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)))
                for i in range(now, length)]

        images = torch.stack(final, dim=0)  # Shape: [remaining(6), 3, 672, 1120]
        images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [remaining, 3, 3, 5, 224, 224]
        images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [remaining, 5, 3, 3, 224, 224]
        images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [remaining, 15, 3, 224, 224]
        images = images.view(-1, 3, 224, 224)  # Shape: [remaining*15, 3, 224, 224]
        images = self.preprocess(images)

        with torch.no_grad():
            logits_N, _ = self.clip_model(images, self.text_N) # Shape: [remaining, 5(num_text_prompts)]
            logits_B, _ = self.clip_model(images, self.text_B) # Shape: [remaining, 5]
            # print('logits_N last get_global_sf', logits_N.shape) # torch.Size([6*15, 5])
            # print('logits_B last get_global_sf', logits_B.shape) #torch.Size([6*15, 5])

        tmp_N = logits_N.softmax(dim=-1).view(length - now, -1) * 10 # Shape: [remaining, 75]
        tmp_B = logits_B.softmax(dim=-1).view(length - now, -1) * 10 # Shape: [remaining, 75]
        # print('tmp_N last get_global_sf', tmp_N.shape)  # torch.Size([6, 75])
        # print('tmp_B last get_global_sf', tmp_B.shape)  # torch.Size([6, 75])

        res.append(torch.cat([tmp_N, tmp_B], dim=1))

    res = torch.cat(res, dim=0)  # Shape: [length, 150]
    # print('res: ', res.shape)  # torch.Size([16, 150]) for toy dataset

    # Step 4: Aggregate into 8 time slots
    chunk_size = length // 8
    final_res = [
        torch.mean(res[i * chunk_size: (i + 1) * chunk_size], dim=0) if i < 7 else torch.mean(res[7 * chunk_size:], dim=0)
        for i in range(8)
    ]

    return torch.stack(final_res, dim=0)  # Shape: [8, 150]
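
Step 4 is a chunked temporal mean: whatever the frame count, the per-frame features are split into 8 roughly equal slots (the last slot absorbs any remainder) and averaged per slot, giving a fixed [8, 150] output. A standalone illustration with dummy features:

import torch

per_frame = torch.rand(16, 150)        # e.g. 16 frames of global features; assumes >= 8 frames
length = per_frame.shape[0]
chunk_size = length // 8

slots = [
    per_frame[i * chunk_size:(i + 1) * chunk_size].mean(dim=0) if i < 7
    else per_frame[7 * chunk_size:].mean(dim=0)   # last slot absorbs the remainder
    for i in range(8)
]
print(torch.stack(slots, dim=0).shape)  # torch.Size([8, 150])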

SimpleVQADataset

Bases: Dataset

Dataset for SimpleVQA. Each sample returns: - spatial_features (torch.Tensor): Extracted spatial frames. - motion_features (torch.Tensor): Extracted motion-based clips. - video_name (str): Video filename.

Source code in aigve/datasets/simplevqa_dataset.py
@DATASETS.register_module()
class SimpleVQADataset(Dataset):
    """
    Dataset for SimpleVQA.
    Each sample returns:
        - spatial_features (torch.Tensor): Extracted spatial frames.
        - motion_features (torch.Tensor): Extracted motion-based clips.
        - video_name (str): Video filename.
    """

    def __init__(self, video_dir, prompt_dir, min_video_seconds=8):
        super(SimpleVQADataset, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.video_dir = video_dir
        self.prompt_dir = prompt_dir
        self.min_video_seconds = min_video_seconds

        self.prompts, self.video_names = self._read_prompt_videoname()

    def _read_prompt_videoname(self):
        with open(self.prompt_dir, 'r') as reader:
            read_data = json.load(reader)

        prompt_data_list, video_name_list = [], []
        for item in read_data["data_list"]:
            prompt = item['prompt_gt'].strip()
            video_name = item['video_path_pd'].strip()
            prompt_data_list.append(prompt)
            video_name_list.append(video_name)

        return prompt_data_list, video_name_list

    def __len__(self):
        return len(self.prompts)

    def video_processing_spatial(self, video_path):
        """
        Extracts spatial frames with proper resizing and normalization.
            - Key frame extraction: It selects 1 frame per second.
            - Standard input size: It resizes frames to 448 * 448 (after an initial resize to 520).
        Return:
            transformed_video (torch.Tensor): shape[video_length_read, 3, 448, 448]. 
                `video_length_read` is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds).
            video_name (str)
        """
        video_capture = cv2.VideoCapture(video_path)
        video_name = os.path.basename(video_path)
        video_length = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
        video_frame_rate = int(round(video_capture.get(cv2.CAP_PROP_FPS)))

        # Compute the number of total seconds of the video
        video_length_read = int(video_length/video_frame_rate) # math.ceil()
        # print('video_length_read (s): ', video_length_read)
        transformations = transforms.Compose([
            transforms.Resize(520),
            transforms.CenterCrop(448),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Standard ImageNet normalization
        ])
        transformed_video = torch.zeros([max(video_length_read, self.min_video_seconds), 3, 448, 448])

        video_read_index = 0
        frame_idx = 0
        for i in range(video_length):
            has_frames, frame = video_capture.read()
            if has_frames:
                # Key frames extraction
                if (video_read_index < video_length_read) and (frame_idx % video_frame_rate == 0):
                    read_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    read_frame = transformations(read_frame)
                    transformed_video[video_read_index] = read_frame
                    video_read_index += 1
                frame_idx += 1

        # Pads remaining frames by repeating the last available frame.
        if video_read_index < self.min_video_seconds:
            for i in range(video_read_index, self.min_video_seconds):
                transformed_video[i] = transformed_video[video_read_index - 1]

        video_capture.release()
        return transformed_video, video_name

    def video_processing_motion(self, video_path):
        """
        Extracts motion-based clips suitable for SlowFast.
            - Standard input size: It resizes frames to 224 * 224.
            - Motion-based clips: Processes at least 8 one-second clips, selecting 32 consecutive frames from each second.
        Return:
            transformed_video_all (List[torch.Tensor]): Tensor shape[video_length_clip(32), 3, 224, 224]. 
                Len(List) is the total number of seconds in the video, with a minimum of 8.
            video_name (str)
        """
        video_capture = cv2.VideoCapture(video_path)
        video_name = os.path.basename(video_path)
        video_length = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
        video_frame_rate = int(round(video_capture.get(cv2.CAP_PROP_FPS)))

        transform = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]) # General purpose
        ])
        transformed_frame_all = torch.zeros([video_length, 3, 224, 224])
        video_read_index = 0
        for i in range(video_length): # All frames extraction
            has_frames, frame = video_capture.read()
            if has_frames:
                read_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                read_frame = transform(read_frame)
                transformed_frame_all[video_read_index] = read_frame
                video_read_index += 1

        # Pads remaining frames by repeating the last available frame.
        if video_read_index < video_length: 
            for i in range(video_read_index, video_length):
                transformed_frame_all[i] = transformed_frame_all[video_read_index - 1]

        video_capture.release()

        # Compute the number of total seconds of the video
        video_clip = int(video_length/video_frame_rate)
        # print('video_clip (s): ', video_clip)
        video_length_clip = 32
        transformed_video_all = []

        # Extract motion-based clips: select 32 consecutive frames from each second
        for i in range(video_clip):
            transformed_video = torch.zeros([video_length_clip, 3, 224, 224])
            if (i * video_frame_rate + video_length_clip) <= video_length: # if the clip can be fully extracted, select 32 consecutive frames starting at i*video_frame_rate
                transformed_video = transformed_frame_all[i * video_frame_rate:(i * video_frame_rate + video_length_clip)]
            else: # Copy all rest available frames. Pads remaining frames by repeating the last available frame.
                transformed_video[:(video_length - i * video_frame_rate)] = transformed_frame_all[i * video_frame_rate:]
                for j in range((video_length - i * video_frame_rate), video_length_clip):
                    transformed_video[j] = transformed_video[video_length - i * video_frame_rate - 1]
            transformed_video_all.append(transformed_video)

        if video_clip < self.min_video_seconds:
            for i in range(video_clip, self.min_video_seconds):
                transformed_video_all.append(transformed_video_all[video_clip - 1])

        return transformed_video_all, video_name

    def __getitem__(self, index):
        """
        Returns:
            spatial_features (torch.Tensor): Shape [v_len_second, 3, 448, 448]
                `v_len_second` is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds).
            motion_features (List[torch.Tensor]): List of motion feature tensors.
                Each tensor has shape [32, 3, 224, 224].
                Len(List) is the total number of seconds in the video (i.e. v_len_second), with a minimum of 8 (i.e. min_video_seconds).
            video_name (str): Video filename
        """
        video_name = self.video_names[index]
        video_path = os.path.join(self.video_dir, video_name)

        spatial_features, video_name = self.video_processing_spatial(video_path)
        motion_features, video_name = self.video_processing_motion(video_path)
        # print('spatial_features: ', spatial_features.shape) # torch.Size([8, 3, 448, 448]) for toy dataset
        # print('motion_features len: ', len(motion_features)) # 8
        # print('motion_features[0]: ', motion_features[0].shape) # torch.Size([32, 3, 224, 224])

        return spatial_features, motion_features, video_name
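
A minimal usage sketch: indexing the dataset directly reproduces the shapes noted in the comments above. The paths are placeholders for a local toy dataset.

from aigve.datasets.simplevqa_dataset import SimpleVQADataset  # module path from "Source code in" above

dataset = SimpleVQADataset(
    video_dir='data/toy/evaluate/',                   # placeholder
    prompt_dir='data/toy/annotations/evaluate.json',  # placeholder
)
spatial_features, motion_features, video_name = dataset[0]
print(spatial_features.shape)    # torch.Size([8, 3, 448, 448]) for the toy dataset
print(len(motion_features))      # 8 clips (>= min_video_seconds)
print(motion_features[0].shape)  # torch.Size([32, 3, 224, 224])
print(video_name)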

__getitem__(index)

Returns:

Name Type Description
spatial_features Tensor

Shape [v_len_second, 3, 448, 448]. v_len_second is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds).

motion_features List[Tensor]

List of motion feature tensors. Each tensor has shape [32, 3, 224, 224]. Len(List) is the total number of seconds in the video (i.e. v_len_second), with a minimum of 8 (i.e. min_video_seconds).

video_name str

Video filename

Source code in aigve/datasets/simplevqa_dataset.py
def __getitem__(self, index):
    """
    Returns:
        spatial_features (torch.Tensor): Shape [v_len_second, 3, 448, 448]
            `v_len_second` is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds).
        motion_features (List[torch.Tensor]): List of motion feature tensors.
            Each tensor has shape [32, 3, 224, 224].
            Len(List) is the total number of seconds in the video (i.e. v_len_second), with a minimum of 8 (i.e. min_video_seconds).
        video_name (str): Video filename
    """
    video_name = self.video_names[index]
    video_path = os.path.join(self.video_dir, video_name)

    spatial_features, video_name = self.video_processing_spatial(video_path)
    motion_features, video_name = self.video_processing_motion(video_path)
    # print('spatial_features: ', spatial_features.shape) # torch.Size([8, 3, 448, 448]) for toy dataset
    # print('motion_features len: ', len(motion_features)) # 8
    # print('motion_features[0]: ', motion_features[0].shape) # torch.Size([32, 3, 224, 224])

    return spatial_features, motion_features, video_name

video_processing_motion(video_path)

Extracts motion-based clips suitable for SlowFast. - Standard input size: It resizes frames to 224 * 224. - Motion-based clips: Processes at least 8 one-second clips, selecting 32 consecutive frames from each second. Return: transformed_video_all (List[torch.Tensor]): Tensor shape[video_length_clip(32), 3, 224, 224]. Len(List) is the total number of seconds in the video, with a minimum of 8. video_name (str)

Source code in aigve/datasets/simplevqa_dataset.py
def video_processing_motion(self, video_path):
    """
    Extracts motion-based clips suitable for SlowFast.
        - Standard input size: It resizes frames to 224 * 224.
        - Motion-based clips: Processes at least 8 one-second clips, selecting 32 consecutive frames from each second.
    Return:
        transformed_video_all (List[torch.Tensor]): Tensor shape[video_length_clip(32), 3, 224, 224]. 
            Len(List) is the total number of seconds in the video, with a minimum of 8.
        video_name (str)
    """
    video_capture = cv2.VideoCapture(video_path)
    video_name = os.path.basename(video_path)
    video_length = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    video_frame_rate = int(round(video_capture.get(cv2.CAP_PROP_FPS)))

    transform = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]) # General purpose
    ])
    transformed_frame_all = torch.zeros([video_length, 3, 224, 224])
    video_read_index = 0
    for i in range(video_length): # All frames extraction
        has_frames, frame = video_capture.read()
        if has_frames:
            read_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            read_frame = transform(read_frame)
            transformed_frame_all[video_read_index] = read_frame
            video_read_index += 1

    # Pads remaining frames by repeating the last available frame.
    if video_read_index < video_length: 
        for i in range(video_read_index, video_length):
            transformed_frame_all[i] = transformed_frame_all[video_read_index - 1]

    video_capture.release()

    # Compute the number of total seconds of the video
    video_clip = int(video_length/video_frame_rate)
    # print('video_clip (s): ', video_clip)
    video_length_clip = 32
    transformed_video_all = []

    # Extract motion-based clips: select 32 consecutive frames from each second
    for i in range(video_clip):
        transformed_video = torch.zeros([video_length_clip, 3, 224, 224])
        if (i * video_frame_rate + video_length_clip) <= video_length: # if the clip can be fully extracted, select 32 consecutive frames starting at i*video_frame_rate
            transformed_video = transformed_frame_all[i * video_frame_rate:(i * video_frame_rate + video_length_clip)]
        else: # Copy all rest available frames. Pads remaining frames by repeating the last available frame.
            transformed_video[:(video_length - i * video_frame_rate)] = transformed_frame_all[i * video_frame_rate:]
            for j in range((video_length - i * video_frame_rate), video_length_clip):
                transformed_video[j] = transformed_video[video_length - i * video_frame_rate - 1]
        transformed_video_all.append(transformed_video)

    if video_clip < self.min_video_seconds:
        for i in range(video_clip, self.min_video_seconds):
            transformed_video_all.append(transformed_video_all[video_clip - 1])

    return transformed_video_all, video_name
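
The clip logic above boils down to: clip i starts at frame i*fps, takes 32 consecutive frames, and pads with the last available frame once the video runs out. A toy sketch of that slicing with a 1-D stand-in for the frame tensor:

import torch

fps, total_frames, clip_len = 24, 50, 32
frames = torch.arange(total_frames)            # stand-in for a [T, 3, 224, 224] frame tensor

clips = []
for i in range(total_frames // fps):           # one clip per full second
    clip = frames[i * fps : i * fps + clip_len]
    if clip.shape[0] < clip_len:               # pad by repeating the last frame
        pad = clip[-1].repeat(clip_len - clip.shape[0])
        clip = torch.cat([clip, pad])
    clips.append(clip)

print(len(clips), clips[-1][-5:])              # 2 clips; the tail of the last clip repeats frame 49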

video_processing_spatial(video_path)

Extracts spatial frames with proper resizing and normalization. - Key frame extraction: It selects 1 frame per second. - Standard input size: It resizes frames to 448 * 448 (after an initial resize to 520). Return: transformed_video (torch.Tensor): shape[video_length_read, 3, 448, 448]. video_length_read is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds). video_name (str)

Source code in aigve/datasets/simplevqa_dataset.py
def video_processing_spatial(self, video_path):
    """
    Extracts spatial frames with proper resizing and normalization.
        - Key frame extraction: It selects 1 frame per second.
        - Standard input size: It resizes frames to 448 * 448 (after an initial resize to 520).
    Return:
        transformed_video (torch.Tensor): shape[video_length_read, 3, 448, 448]. 
            `video_length_read` is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds).
        video_name (str)
    """
    video_capture = cv2.VideoCapture(video_path)
    video_name = os.path.basename(video_path)
    video_length = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    video_frame_rate = int(round(video_capture.get(cv2.CAP_PROP_FPS)))

    # Compute the number of total seconds of the video
    video_length_read = int(video_length/video_frame_rate) # math.ceil()
    # print('video_length_read (s): ', video_length_read)
    transformations = transforms.Compose([
        transforms.Resize(520),
        transforms.CenterCrop(448),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Standard ImageNet normalization
    ])
    transformed_video = torch.zeros([max(video_length_read, self.min_video_seconds), 3, 448, 448])

    video_read_index = 0
    frame_idx = 0
    for i in range(video_length):
        has_frames, frame = video_capture.read()
        if has_frames:
            # Key frames extraction
            if (video_read_index < video_length_read) and (frame_idx % video_frame_rate == 0):
                read_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                read_frame = transformations(read_frame)
                transformed_video[video_read_index] = read_frame
                video_read_index += 1
            frame_idx += 1

    # Pads remaining frames by repeating the last available frame.
    if video_read_index < self.min_video_seconds:
        for i in range(video_read_index, self.min_video_seconds):
            transformed_video[i] = transformed_video[video_read_index - 1]

    video_capture.release()
    return transformed_video, video_name

SimpleVqa

Bases: BaseMetric

SimpleVQA metric for evaluating video quality.

Source code in aigve/metrics/video_quality_assessment/nn_based/simplevqa/simplevqa_metric.py
@METRICS.register_module()
class SimpleVqa(BaseMetric):
    """SimpleVQA metric for evaluating video quality."""
    def __init__(self, model_path: str, is_gpu: bool = True):
        super(SimpleVqa, self).__init__()
        self.model_path = model_path
        self.device = torch.device("cuda" if is_gpu else "cpu")
        self.submodel_path = os.path.join(os.getcwd(), 'metrics/video_quality_assessment/nn_based/simplevqa')
        if not submodule_exists(self.submodel_path):
            add_git_submodule(
                repo_url='https://github.com/sunwei925/SimpleVQA.git', 
                submodule_path=self.submodel_path
            )
        simplevqa_path = os.path.join(self.submodel_path, "SimpleVQA")
        if simplevqa_path not in sys.path:
            sys.path.insert(0, simplevqa_path)
        from .SimpleVQA.model import UGC_BVQA_model
        from .SimpleVQA.test_demo import slowfast
        self.model_motion = slowfast().to(self.device)
        self.model = UGC_BVQA_model.resnet50(pretrained=False)
        self.model = torch.nn.DataParallel(self.model).to(self.device)
        self.model.load_state_dict(torch.load(os.path.join(os.getcwd(), self.model_path), map_location=self.device))
        self.model.eval()

    def process(self, data_batch: list, data_samples: list) -> None:
        """
        Process a batch of extracted deep features for SimpleVQA evaluation.
        Args:
            data_batch (Sequence): A batch of data from the dataloader (not used here).
            data_samples (List[ Tuple[torch.Tensor], List[Tuple[torch.Tensor]], Tuple[str] ]):
                A list containing three tuples:
                - A tuple of `spatial_features` (torch.Tensor): Shape [v_len_second, 3, 448, 448]. 
                    `v_len_second` is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds).
                    The len of the tuple is the batch size. 
                - A list of `motion_features` (Tuple[torch.Tensor]): 
                    len(List) is the total number of seconds in the video, with a minimum of 8 (i.e. min_video_seconds).
                    Each item of the list is a Tuple of motion feature tensors. Each has shape [32, 3, 224, 224].
                    The len of the tuple is the batch size.
                - A tuple of `video_name` (str): Video filename. The len of the tuple is the batch size.
        """
        from .SimpleVQA.test_demo import pack_pathway_output

        results = []
        # print(type(data_samples)) # list
        spatial_features_tuple, motion_features_list, video_name_tuple = data_samples
        # print(len(spatial_features_tuple)) # 1
        # print(spatial_features_tuple[0].shape) # torch.Size([8, 3, 448, 448])

        # print(type(motion_features_list)) # List
        # print(len(motion_features_list)) # 8
        # print(type(motion_features_list[0])) # tuple
        # print(len(motion_features_list[0])) # 1
        # print(type(motion_features_list[0][0])) # Tensor
        # print(motion_features_list[0][0].shape) # torch.Size([32, 3, 224, 224])

        batch_size = len(spatial_features_tuple)
        with torch.no_grad():
            for i in range(batch_size):
                video_name = video_name_tuple[i]
                spatial_features = spatial_features_tuple[i].to(self.device).unsqueeze(0)  # Add batch dim. Shape: tensor with Size([1, v_len_second, 3, 448, 448])

                # Take the i-th element from each tuple in motion_features_list
                motion_features = [motion_features_list[j][i] for j in range(len(motion_features_list))] # Shape: List[tensor with Size([32, 3, 224, 224])], len of it is total seconds of the video, with minium 8.

                if not all(isinstance(mf, torch.Tensor) for mf in motion_features):
                    raise TypeError("Expected motion_features to be a list of tensors.")

                if len(motion_features) == 0:  # Edge case: No valid motion features
                    results.append({"video_name": video_name, "SimpleVQA_Score": 0.0})
                    continue

                n_clip = len(motion_features)  # 8
                feature_motion = torch.zeros([n_clip, 2048 + 256], device=self.device) 
                # Process each motion clip
                for idx, clip in enumerate(motion_features):
                    clip = clip.unsqueeze(dim=0).permute(0, 2, 1, 3, 4)  # Reshape to [1, C(3), T(32), H(224), W(224)]
                    clip = pack_pathway_output(clip, self.device)  # Convert to SlowFast format
                    slow_feature, fast_feature = self.model_motion(clip)
                    slow_feature = slow_feature.squeeze()
                    fast_feature = fast_feature.squeeze()

                    motion_feature = torch.cat([slow_feature, fast_feature]).unsqueeze(0)  # Shape: [1, 2304]
                    feature_motion[idx] = motion_feature 

                feature_motion = feature_motion.unsqueeze(0)  # Shape: [1, n_clip, 2304]

                outputs = self.model(spatial_features, feature_motion)
                score = outputs.item()

                results.append({"video_name": video_name, "SimpleVQA_Score": score})
                print(f"Processed score {score:.4f} for {video_name}")

        self.results.extend(results)

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute final SimpleVQA-based metrics."""
        scores = np.array([res["SimpleVQA_Score"] for res in self.results])
        mean_score = np.mean(scores) if scores.size > 0 else 0.0
        print(f"SimpleVQA mean score: {mean_score:.4f}")

        json_file_path = os.path.join(os.getcwd(), "simplevqa_results.json")
        final_results = {"video_results": self.results, "SimpleVQA_Mean_Score": mean_score}
        with open(json_file_path, "w") as json_file:
            json.dump(final_results, json_file, indent=4)
        print(f"SimpleVQA mean score saved to {json_file_path}")

        return {"SimpleVQA_Mean_Score": mean_score}

compute_metrics(results)

Compute final SimpleVQA-based metrics.

Source code in aigve/metrics/video_quality_assessment/nn_based/simplevqa/simplevqa_metric.py
def compute_metrics(self, results: list) -> Dict[str, float]:
    """Compute final SimpleVQA-based metrics."""
    scores = np.array([res["SimpleVQA_Score"] for res in self.results])
    mean_score = np.mean(scores) if scores.size > 0 else 0.0
    print(f"SimpleVQA mean score: {mean_score:.4f}")

    json_file_path = os.path.join(os.getcwd(), "simplevqa_results.json")
    final_results = {"video_results": self.results, "SimpleVQA_Mean_Score": mean_score}
    with open(json_file_path, "w") as json_file:
        json.dump(final_results, json_file, indent=4)
    print(f"SimpleVQA mean score saved to {json_file_path}")

    return {"SimpleVQA_Mean_Score": mean_score}

process(data_batch, data_samples)

Process a batch of extracted deep features for SimpleVQA evaluation. Args: data_batch (Sequence): A batch of data from the dataloader (not used here). data_samples (List[ Tuple[torch.Tensor], List[Tuple[torch.Tensor]], Tuple[str] ]): A list containing three tuples: - A tuple of spatial_features (torch.Tensor): Shape [v_len_second, 3, 448, 448]. v_len_second is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds). The len of the tuple is the batch size. - A list of motion_features (Tuple[torch.Tensor]): len(List) is the total number of seconds in the video, with a minimum of 8 (i.e. min_video_seconds). Each item of the list is a Tuple of motion feature tensors. Each has shape [32, 3, 224, 224]. The len of the tuple is the batch size. - A tuple of video_name (str): Video filename. The len of the tuple is the batch size.

Source code in aigve/metrics/video_quality_assessment/nn_based/simplevqa/simplevqa_metric.py
def process(self, data_batch: list, data_samples: list) -> None:
    """
    Process a batch of extracted deep features for SimpleVQA evaluation.
    Args:
        data_batch (Sequence): A batch of data from the dataloader (not used here).
        data_samples (List[ Tuple[torch.Tensor], List[Tuple[torch.Tensor]], Tuple[str] ]):
            A list containing three tuples:
            - A tuple of `spatial_features` (torch.Tensor): Shape [v_len_second, 3, 448, 448]. 
                `v_len_second` is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds).
                The len of the tuple is the batch size. 
            - A list of `motion_features` (Tuple[torch.Tensor]): 
                len(List) is the total number of seconds in the video, with a minimum of 8 (i.e. min_video_seconds).
                Each item of the list is a Tuple of motion feature tensors. Each has shape [32, 3, 224, 224].
                The len of the tuple is the batch size.
            - A tuple of `video_name` (str): Video filename. The len of the tuple is the batch size.
    """
    from .SimpleVQA.test_demo import pack_pathway_output

    results = []
    # print(type(data_samples)) # list
    spatial_features_tuple, motion_features_list, video_name_tuple = data_samples
    # print(len(spatial_features_tuple)) # 1
    # print(spatial_features_tuple[0].shape) # torch.Size([8, 3, 448, 448])

    # print(type(motion_features_list)) # List
    # print(len(motion_features_list)) # 8
    # print(type(motion_features_list[0])) # tuple
    # print(len(motion_features_list[0])) # 1
    # print(type(motion_features_list[0][0])) # Tensor
    # print(motion_features_list[0][0].shape) # torch.Size([32, 3, 224, 224])

    batch_size = len(spatial_features_tuple)
    with torch.no_grad():
        for i in range(batch_size):
            video_name = video_name_tuple[i]
            spatial_features = spatial_features_tuple[i].to(self.device).unsqueeze(0)  # Add batch dim. Shape: tensor with Size([1, v_len_second, 3, 448, 448])

            # Take the i-th element from each tuple in motion_features_list
            motion_features = [motion_features_list[j][i] for j in range(len(motion_features_list))] # Shape: List[tensor with Size([32, 3, 224, 224])], len of it is total seconds of the video, with minium 8.

            if not all(isinstance(mf, torch.Tensor) for mf in motion_features):
                raise TypeError("Expected motion_features to be a list of tensors.")

            if len(motion_features) == 0:  # Edge case: No valid motion features
                results.append({"video_name": video_name, "SimpleVQA_Score": 0.0})
                continue

            n_clip = len(motion_features)  # 8
            feature_motion = torch.zeros([n_clip, 2048 + 256], device=self.device) 
            # Process each motion clip
            for idx, clip in enumerate(motion_features):
                clip = clip.unsqueeze(dim=0).permute(0, 2, 1, 3, 4)  # Reshape to [1, C(3), T(32), H(224), W(224)]
                clip = pack_pathway_output(clip, self.device)  # Convert to SlowFast format
                slow_feature, fast_feature = self.model_motion(clip)
                slow_feature = slow_feature.squeeze()
                fast_feature = fast_feature.squeeze()

                motion_feature = torch.cat([slow_feature, fast_feature]).unsqueeze(0)  # Shape: [1, 2304]
                feature_motion[idx] = motion_feature 

            feature_motion = feature_motion.unsqueeze(0)  # Shape: [1, n_clip, 2304]

            outputs = self.model(spatial_features, feature_motion)
            score = outputs.item()

            results.append({"video_name": video_name, "SimpleVQA_Score": score})
            print(f"Processed score {score:.4f} for {video_name}")

    self.results.extend(results)
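
The nesting described above can be reproduced with plain Python for a batch of one video, which is useful when exercising the metric outside a runner. The tensors below are random stand-ins with the documented shapes; `metric` is assumed to be a constructed SimpleVqa instance.

import torch

# Batch of one video: inner tuples hold per-sample items, the outer list groups the three fields.
spatial_features_tuple = (torch.rand(8, 3, 448, 448),)
motion_features_list = [(torch.rand(32, 3, 224, 224),) for _ in range(8)]
video_name_tuple = ('video_0.mp4',)

data_samples = [spatial_features_tuple, motion_features_list, video_name_tuple]
# metric.process(data_batch=[], data_samples=data_samples)  # assuming metric = SimpleVqa(model_path=...)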

TIFAScore

Bases: BaseMetric

Initialize the TIFAScore evaluator.

Parameters:

Name Type Description Default
openai_key str

The user's API key for the OpenAI LLM models.

required
llm_model str

The name of the LLM model used in the TIFAScore evaluator. Defaults to gpt-3.5-turbo.

'gpt-3.5-turbo'
unifiedqa_model_name str

The name of the UnifiedQAModel used in the TIFAScore evaluator. Defaults to allenai/unifiedqa-v2-t5-large-1363200.

'allenai/unifiedqa-v2-t5-large-1363200'
vqa_model_name str

The name of the AIGVEModel used in the TIFAScore evaluator. Defaults to mplug-large.

'mplug-large'
Source code in aigve/metrics/text_video_alignment/gpt_based/TIFA/tifa_eval.py
@METRICS.register_module()
class TIFAScore(BaseMetric):
    """ Initialize the ``TIFAScore`` evaluator.

    Args:   
        openai_key (str): The user's API key for the OpenAI LLM models.
        llm_model (str): The name of the LLM model used in the TIFAScore evaluator. Defaults to ``gpt-3.5-turbo``.
        unifiedqa_model_name (str): The name of the ``UnifiedQAModel`` used in the TIFAScore evaluator. Defaults to ``allenai/unifiedqa-v2-t5-large-1363200``.
        vqa_model_name (str): The name of the ``AIGVEModel`` used in the TIFAScore evaluator. Defaults to ``mplug-large``.
    """
    def __init__(self, 
                 openai_key,
                 llm_model: str = 'gpt-3.5-turbo',
                 unifiedqa_model_name: str = 'allenai/unifiedqa-v2-t5-large-1363200',
                 vqa_model_name: str = 'mplug-large'):
        super().__init__()

        self.openai_key = openai_key
        self.llm_model = llm_model
        self.unifiedqa_model_name = unifiedqa_model_name
        self.openai_completion, self.get_question_and_answers, self.filter_question_and_answers, self.unifiedqa_model, self.tifa_score_single, self.vqa_model = lazy_import()
        self.unifiedqa_model = self.UnifiedQAModel(self.unifiedqa_model_name)
        self.vqa_model_name = vqa_model_name
        self.vqa_model = self.AIGVEModel(self.vqa_model_name)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.openai_setup()

    def openai_setup(self):
        print('set up openai client')
        openai.api_key = self.openai_key
        assert openai.api_key is not None
        test_prompt_string = 'hello, how are you doing?'
        print('test prompt: ', test_prompt_string)
        response = self.openai_completion(
            test_prompt_string,
            model=self.llm_model,
        )
        print('test response: ', response)


    def process(self, data_batch: Sequence, data_samples: Sequence) -> None:
        """ TIFAScore process
        Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence): A batch of data from the dataloader.
            data_samples (Sequence): A batch of data samples that
                contain annotations and predictions.
        """

        result = dict()

        input_prompts, input_videos = data_samples
        bsz = len(input_prompts)

        # Ensure prompt_input is a tensor
        if isinstance(input_prompts, tuple):
            input_prompts = list(input_prompts)

        if isinstance(input_videos, tuple):
            input_videos = list(input_videos)

        average_tifa_score_list = []
        for input_prompt, input_video in zip(input_prompts, input_videos):
            tifa_score = []
            # Generate questions with GPT-3.5-turbo
            gpt3_questions = self.get_question_and_answers(input_prompt)
            # print(gpt3_questions)
            # Filter questions with UnifiedQA
            filtered_questions = self.filter_question_and_answers(self.unifiedqa_model, gpt3_questions)
            for index, frame_path in enumerate(input_video):
                # Calculate TIFA score
                result = self.tifa_score_single(self.vqa_model, filtered_questions, frame_path)
                # print(result)
                tifa_score.append(result['tifa_score'])
            average_tifa_score = sum(tifa_score)/len(tifa_score)
            average_tifa_score_list.append(average_tifa_score)

        result['tifa_score'] = sum(average_tifa_score_list)/len(average_tifa_score_list)

        self.results.append(result)


    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
            the metrics, and the values are corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        tifa_score_np = np.zeros(len(results))
        for i, result in enumerate(results):
            tifa_score_np[i] = result['tifa_score']

        tifa_score_np_mean = np.mean(tifa_score_np) 

        print("Test results: tifa score={:.4f}"
              .format(tifa_score_np_mean))

        return {'tifa_score': tifa_score_np_mean}

compute_metrics(results)

Compute the metrics from processed results.

Parameters:

Name Type Description Default
results list

The processed results of each batch.

required

Returns:

Type Description
Dict[str, float]

Dict[str, float]: The computed metrics. The keys are the names of the metrics, and the values are corresponding results.

Source code in aigve/metrics/text_video_alignment/gpt_based/TIFA/tifa_eval.py
def compute_metrics(self, results: list) -> Dict[str, float]:
    """Compute the metrics from processed results.

    Args:
        results (list): The processed results of each batch.

    Returns:
        Dict[str, float]: The computed metrics. The keys are the names of
        the metrics, and the values are corresponding results.
    """
    logger: MMLogger = MMLogger.get_current_instance()

    tifa_score_np = np.zeros(len(results))
    for i, result in enumerate(results):
        tifa_score_np[i] = result['tifa_score']

    tifa_score_np_mean = np.mean(tifa_score_np) 

    print("Test results: tifa score={:.4f}"
          .format(tifa_score_np_mean))

    return {'tifa_score': tifa_score_np_mean}

process(data_batch, data_samples)

TIFAScore process Process one batch of data samples and predictions. The processed results should be stored in self.results, which will be used to compute the metrics when all batches have been processed.

Parameters:

Name Type Description Default
data_batch Sequence

A batch of data from the dataloader.

required
data_samples Sequence

A batch of data samples that contain annotations and predictions.

required
Source code in aigve/metrics/text_video_alignment/gpt_based/TIFA/tifa_eval.py
def process(self, data_batch: Sequence, data_samples: Sequence) -> None:
    """ TIFAScore process
    Process one batch of data samples and predictions. The processed
    results should be stored in ``self.results``, which will be used to
    compute the metrics when all batches have been processed.

    Args:
        data_batch (Sequence): A batch of data from the dataloader.
        data_samples (Sequence): A batch of data samples that
            contain annotations and predictions.
    """

    result = dict()

    input_prompts, input_videos = data_samples
    bsz = len(input_prompts)

    # Ensure prompt_input is a tensor
    if isinstance(input_prompts, tuple):
        input_prompts = list(input_prompts)

    if isinstance(input_videos, tuple):
        input_videos = list(input_videos)

    average_tifa_score_list = []
    for input_prompt, input_video in zip(input_prompts, input_videos):
        tifa_score = []
        # Generate questions with GPT-3.5-turbo
        gpt3_questions = self.get_question_and_answers(input_prompt)
        # print(gpt3_questions)
        # Filter questions with UnifiedQA
        filtered_questions = self.filter_question_and_answers(self.unifiedqa_model, gpt3_questions)
        for index, frame_path in enumerate(input_video):
            # Calculate TIFA score
            result = self.tifa_score_single(self.vqa_model, filtered_questions, frame_path)
            # print(result)
            tifa_score.append(result['tifa_score'])
        average_tifa_score = sum(tifa_score)/len(tifa_score)
        average_tifa_score_list.append(average_tifa_score)

    result['tifa_score'] = sum(average_tifa_score_list)/len(average_tifa_score_list)

    self.results.append(result)

ToyDataset

Bases: BaseDataset

ToyDataset for testing.

Parameters:

Name Type Description Default
data_root str

Root directory for data.

None
ann_file str

Annotation file path.

''
metainfo dict

Metadata information.

None
data_prefix dict

Prefix paths for different modalities.

None
pipeline List[Union[Callable, dict]]

Data transformation pipeline.

[]
modality dict

Specifies which modalities are used (video, text, image).

dict(use_video=True, use_text=True, use_image=False)
image_frame int

Number of frames for images.

None
Source code in aigve/datasets/toy_dataset.py
@DATASETS.register_module()
class ToyDataset(BaseDataset):
    """ToyDataset for testing.

    Args:
        data_root (str, optional): Root directory for data.
        ann_file (str): Annotation file path.
        metainfo (dict, optional): Metadata information.
        data_prefix (dict): Prefix paths for different modalities.
        pipeline (List[Union[Callable, dict]]): Data transformation pipeline.
        modality (dict): Specifies which modalities are used (video, text, image).
        image_frame (int, optional): Number of frames for images.
    """

    def __init__(self,
                 data_root: Optional[str] = None,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = None,
                 pipeline: List[Union[Callable, dict]] = [],
                 modality: dict = dict(use_video=True, use_text=True, use_image=False),
                 image_frame: int = None,
                 **kwargs) -> None:
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            metainfo=metainfo,
            data_prefix=data_prefix,
            pipeline=pipeline,
            **kwargs
        )
        self.modality = modality
        self.image_frame = image_frame
        assert self.modality['use_video'] or self.modality['use_text'], (
            'Please specify the `modality` (`use_video` '
            f', `use_text`) for {self.__class__.__name__}')

    def parse_data_info(self, raw_data_info: dict) -> dict:
        """Parse raw data info."""
        info = {}
        info['img_frame'] = None
        if self.modality['use_text']:
            info['prompt_gt'] = osp.join(self.data_prefix.get('video', ''), 
                                         raw_data_info['prompt_gt'])

        if self.modality['use_video'] or self.modality['use_image']:
            info['video_path_pd'] = osp.join(self.data_prefix.get('video', ''), 
                                     raw_data_info['video_path_pd'])
            if self.modality['use_image']:
                info['img_frame'] = self.image_frame

        return info

parse_data_info(raw_data_info)

Parse raw data info.

Source code in aigve/datasets/toy_dataset.py
def parse_data_info(self, raw_data_info: dict) -> dict:
    """Parse raw data info."""
    info = {}
    info['img_frame'] = None
    if self.modality['use_text']:
        info['prompt_gt'] = osp.join(self.data_prefix.get('video', ''), 
                                     raw_data_info['prompt_gt'])

    if self.modality['use_video'] or self.modality['use_image']:
        info['video_path_pd'] = osp.join(self.data_prefix.get('video', ''), 
                                 raw_data_info['video_path_pd'])
        if self.modality['use_image']:
            info['img_frame'] = self.image_frame

    return info
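
For orientation, each raw annotation entry is expected to carry a prompt_gt text field and a video_path_pd path, which parse_data_info joins with the configured video prefix. The snippet below sketches a hypothetical MMEngine-style dataset config using only the documented arguments; the concrete paths are placeholders, not part of the library.

# Hypothetical config sketch; 'data/toy', 'annotations.json', and 'videos/' are placeholders.
toy_dataset = dict(
    type='ToyDataset',
    data_root='data/toy',
    ann_file='annotations.json',
    data_prefix=dict(video='videos/'),
    modality=dict(use_video=True, use_text=True, use_image=False),
    pipeline=[],
)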

VIEEvalScore

Bases: BaseMetric

Initialize the VIEEvalScore evaluator.

Parameters:

Name Type Description Default
llm_backbone str

The name of the LLM model used in the VIEEvalScore evaluator. Defaults to gpt4o.

'gpt4o'
api_key_path str

Path to the user's API key used to initialize the LLM models provided by OpenAI.

'AIGVE_Tool/metrics/text_video_alignment/gpt_based/VIE/api_key.txt'
task str

The task the VIEEvalScore evaluator conducts. Defaults to 't2v'.

't2v'
Source code in aigve/metrics/text_video_alignment/gpt_based/VIE/vie_eval.py
@METRICS.register_module()
class VIEEvalScore(BaseMetric):
    """ Initialize the ``VIEEvalScore`` evaluator.

    Args:
        llm_backbone (str): The name of the LLM model used in the VIEEvalScore evaluator. Defaults to ``gpt4o``.
        api_key_path (str): Path to the user's API key used to initialize the LLM models provided by OpenAI.
        task (str): The task the VIEEvalScore evaluator conducts. Defaults to ``'t2v'``.
    """
    def __init__(self,
                 llm_backbone: str = "gpt4o",
                 api_key_path: str = 'AIGVE_Tool/metrics/text_video_alignment/gpt_based/VIE/api_key.txt',
                 task: str = 't2v',
                 ):
        super().__init__()

        self.api_key_path = api_key_path
        self.llm_backbone = llm_backbone
        self.task = task

        self.submodel_path = 'metrics/text_video_alignment/gpt_based/VIE'
        if not submodule_exists(self.submodel_path):
            add_git_submodule(
                repo_url='https://github.com/TIGER-AI-Lab/VIEScore.git', 
                submodule_path=self.submodel_path
            )  
        self.submodel_path = 'metrics/text_video_alignment/gpt_based/dsg'
        if not submodule_exists(self.submodel_path):
            add_git_submodule(
                repo_url='https://github.com/j-min/DSG.git', 
                submodule_path=self.submodel_path
            )  
        from .VIEScore.viescore import VIEScore 
        from .DSG.dsg.vqa_utils import MPLUG, InstructBLIP


        self.vie_score = VIEScore(backbone=self.llm_backbone, task=self.task, key_path=self.api_key_path)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    def process(self, data_batch: Sequence, data_samples: Sequence) -> None:
        """VIEScore process
        Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence): A batch of data from the dataloader.
            data_samples (Sequence): A batch of data samples that
                contain annotations and predictions.
        """

        result = dict()

        input_prompts, input_videos = data_samples
        bsz = len(input_prompts)

        # Ensure the prompt and video inputs are lists
        if isinstance(input_prompts, tuple):
            input_prompts = list(input_prompts)

        if isinstance(input_videos, tuple):
            input_videos = list(input_videos)

        average_vie_score_list = []
        for input_prompt, input_video in zip(input_prompts, input_videos):
            vie_score_list = []
            for index, frame_path in enumerate(input_video):
                pil_image = Image.open(frame_path)
                score_list = self.vie_score.evaluate(pil_image, input_prompt)
                semantics_score, quality_score, overall_score = score_list
                vie_score_list.append(overall_score)
            average_vie_score = sum(vie_score_list)/len(vie_score_list)
            average_vie_score_list.append(average_vie_score)

        result['vie_score'] = sum(average_vie_score_list)/len(average_vie_score_list)

        self.results.append(result)


    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
            the metrics, and the values are corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        vie_score_np = np.zeros(len(results))
        for i, result in enumerate(results):
            vie_score_np[i] = result['vie_score']

        vie_score_np_mean = np.mean(vie_score_np) 

        print("Test results: vie score with dependency={:.4f}"
              .format(vie_score_np_mean))

        return {'vie_score': vie_score_np_mean}

compute_metrics(results)

Compute the metrics from processed results.

Parameters:

Name Type Description Default
results list

The processed results of each batch.

required

Returns:

Type Description
Dict[str, float]

Dict[str, float]: The computed metrics. The keys are the names of

Dict[str, float]

the metrics, and the values are corresponding results.

Source code in aigve/metrics/text_video_alignment/gpt_based/VIE/vie_eval.py
def compute_metrics(self, results: list) -> Dict[str, float]:
    """Compute the metrics from processed results.

    Args:
        results (list): The processed results of each batch.

    Returns:
        Dict[str, float]: The computed metrics. The keys are the names of
        the metrics, and the values are corresponding results.
    """
    logger: MMLogger = MMLogger.get_current_instance()

    vie_score_np = np.zeros(len(results))
    for i, result in enumerate(results):
        vie_score_np[i] = result['vie_score']

    vie_score_np_mean = np.mean(vie_score_np) 

    print("Test results: vie score with dependency={:.4f}"
          .format(vie_score_np_mean))

    return {'vie_score': vie_score_np_mean}

process(data_batch, data_samples)

VIEScore process Process one batch of data samples and predictions. The processed results should be stored in self.results, which will be used to compute the metrics when all batches have been processed.

Parameters:

Name Type Description Default
data_batch Sequence

A batch of data from the dataloader.

required
data_samples Sequence

A batch of data samples that contain annotations and predictions.

required
Source code in aigve/metrics/text_video_alignment/gpt_based/VIE/vie_eval.py
def process(self, data_batch: Sequence, data_samples: Sequence) -> None:
    """VIEScore process
    Process one batch of data samples and predictions. The processed
    results should be stored in ``self.results``, which will be used to
    compute the metrics when all batches have been processed.

    Args:
        data_batch (Sequence): A batch of data from the dataloader.
        data_samples (Sequence): A batch of data samples that
            contain annotations and predictions.
    """

    result = dict()

    input_prompts, input_videos = data_samples
    bsz = len(input_prompts)

    # Ensure the prompt and video inputs are lists
    if isinstance(input_prompts, tuple):
        input_prompts = list(input_prompts)

    if isinstance(input_videos, tuple):
        input_videos = list(input_videos)

    average_vie_score_list = []
    for input_prompt, input_video in zip(input_prompts, input_videos):
        vie_score_list = []
        for index, frame_path in enumerate(input_video):
            pil_image = Image.open(frame_path)
            score_list = self.vie_score.evaluate(pil_image, input_prompt)
            semantics_score, quality_score, overall_score = score_list
            vie_score_list.append(overall_score)
        average_vie_score = sum(vie_score_list)/len(vie_score_list)
        average_vie_score_list.append(average_vie_score)

    result['vie_score'] = sum(average_vie_score_list)/len(average_vie_score_list)

    self.results.append(result)
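
As a rough illustration of how this metric could be declared in an evaluation config, the snippet below sketches a hypothetical MMEngine-style entry built from the documented __init__ arguments; the surrounding config structure is an assumption, not part of the library.

# Hypothetical metric config sketch; only the three __init__ arguments shown
# in the documentation above are used.
vie_metric = dict(
    type='VIEEvalScore',
    llm_backbone='gpt4o',
    api_key_path='AIGVE_Tool/metrics/text_video_alignment/gpt_based/VIE/api_key.txt',
    task='t2v',
)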

VideoPhy

Bases: BaseMetric

Source code in aigve/metrics/multi_aspect_metrics/videophy/videophy_metric.py
@METRICS.register_module()
class VideoPhy(BaseMetric):
    def __init__(self,
                hf_token: str,
                collect_device: Optional[Union[str, torch.device]] = None,
                prefix: Optional[str] = None,
                metric_path: str = None,
                model_path: str = 'videophysics/videocon_physics',
                datainfo_path: str = None,
                test_index: int = None,
                 **kwargs):

        """
        This function is used to initialize the VideoPhy metric.

        Args:
            hf_token (str): The Hugging Face access token used to download the tokenizer and model
            collect_device (str or torch.device): The device to use for collecting the data
            prefix (str): The prefix to use for the metric name
            metric_path (str): The path to the metric
            model_path (str): The path to the model
            datainfo_path (str): The path to the data info
            test_index (int): The index of the test
        """

        super().__init__(collect_device=collect_device, prefix=prefix)
        # self.train_index = train_index
        self.metric_path = metric_path
        self.model_path = model_path
        self.datainfo_path = datainfo_path
        self.test_index = test_index
        self.hf_token = hf_token
        self.results = []

        # self.submodule_path = './metrics/aigve'
        # if not submodule_exists(self.submodule_path):
        #     add_git_submodule(
        #         repo_url='https://github.com/Hritikbansal/videophy.git',
        #         submodule_path=self.submodule_path
        #     )

        self.tokenizer = LlamaTokenizer.from_pretrained(self.model_path, token=self.hf_token)
        self.image_processor = MplugOwlImageProcessor.from_pretrained(self.model_path)
        self.processor = MplugOwlProcessor(self.image_processor, self.tokenizer)
        self.model = MplugOwlForConditionalGeneration.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
        ).to('cuda')
        self.model.eval()

    def get_entail(self, logits, input_ids):
        """
        This function is used to get the entailment scores.

        Args:
            logits (torch.Tensor): A tensor containing the logits
            input_ids (torch.Tensor): A tensor containing the input IDs
        """
        softmax = nn.Softmax(dim=2)
        logits = softmax(logits)
        token_id_yes = self.tokenizer.encode('Yes', add_special_tokens=False)[0]
        token_id_no = self.tokenizer.encode('No', add_special_tokens=False)[0]
        entailment = []
        for j in range(len(logits)):
            for i in range(len(input_ids[j])):
                if input_ids[j][i] == self.tokenizer.pad_token_id:  # pad token if the answer is not present
                    i = i - 1
                    break
                elif i == len(input_ids[j]) - 1:
                    break
            score = logits[j][i][token_id_yes] / (logits[j][i][token_id_yes] + logits[j][i][token_id_no])
            entailment.append(score)
        entailment = torch.stack(entailment)
        return entailment

    def get_logits(self, data_batch):
        """
        This function is used to get the logits for each input in the data batch.

        Args:
            data_batch (dict): A dictionary containing the data batch
        Returns:
            logits (torch.Tensor): A tensor containing the logits for each input in the data batch
        """
        # Iterate over each item in the data batch
        for k, v in data_batch.items():
            # Check if the item is a tensor
            if torch.is_tensor(v):
                # Convert float tensors to bfloat16
                if v.dtype == torch.float:
                    data_batch[k] = v.bfloat16()
                # Move the tensor to the model's device (e.g., GPU)
                data_batch[k] = data_batch[k].to(self.model.device)

        # print("Data batch: ", data_batch.keys())
        outputs = self.model(pixel_values=data_batch['pixel_values'], video_pixel_values=data_batch['video_pixel_values'],
                        labels=None, \
                        num_images=data_batch['num_images'], num_videos=data_batch['num_videos'], input_ids=data_batch['input_ids'],
                        non_padding_mask=data_batch['non_padding_mask'], \
                        non_media_mask=data_batch['non_media_mask'], prompt_mask=data_batch['prompt_mask'])
        logits = outputs['logits']
        return logits


    def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
        """
        This function is used to process the data batch and compute the metric.

        Args:
            data_batch (dict): A dictionary containing the data batch
            data_samples (list): A list of dictionaries containing the data samples
        """
        logits = self.get_logits(data_batch)
        entails_scores =  self.get_entail(logits, data_batch['input_ids'])

        self.results.extend(entails_scores.cpu().detach().to(torch.float32).numpy().tolist())
        # self.results = entails_scores.cpu().detach().to(torch.float32).numpy().tolist()
        # print(self.results)


    def compute_metrics(self, results: list) -> dict:
        """
        This function is used to compute the metrics.

        Args:
            results (list): A list of results
        """
        return {
            'entailment': float(np.mean(results))
        }

__init__(hf_token, collect_device=None, prefix=None, metric_path=None, model_path='videophysics/videocon_physics', datainfo_path=None, test_index=None, **kwargs)

This function is used to initialize the VideoPhy metric.

Parameters:

Name Type Description Default
hf_token str

The Hugging Face access token used to download the tokenizer and model

required
collect_device str or device

The device to use for collecting the data

None
prefix str

The prefix to use for the metric name

None
metric_path str

The path to the metric

None
model_path str

The path to the model

'videophysics/videocon_physics'
datainfo_path str

The path to the data info

None
test_index int

The index of the test

None
Source code in aigve/metrics/multi_aspect_metrics/videophy/videophy_metric.py
def __init__(self,
            hf_token: str,
            collect_device: Optional[Union[str, torch.device]] = None,
            prefix: Optional[str] = None,
            metric_path: str = None,
            model_path: str = 'videophysics/videocon_physics',
            datainfo_path: str = None,
            test_index: int = None,
             **kwargs):

    """
    This function is used to initialize the VideoPhy metric.

    Args:
        hf_token (str): The Hugging Face access token used to download the tokenizer and model
        collect_device (str or torch.device): The device to use for collecting the data
        prefix (str): The prefix to use for the metric name
        metric_path (str): The path to the metric
        model_path (str): The path to the model
        datainfo_path (str): The path to the data info
        test_index (int): The index of the test
    """

    super().__init__(collect_device=collect_device, prefix=prefix)
    # self.train_index = train_index
    self.metric_path = metric_path
    self.model_path = model_path
    self.datainfo_path = datainfo_path
    self.test_index = test_index
    self.hf_token = hf_token
    self.results = []

    # self.submodule_path = './metrics/aigve'
    # if not submodule_exists(self.submodule_path):
    #     add_git_submodule(
    #         repo_url='https://github.com/Hritikbansal/videophy.git',
    #         submodule_path=self.submodule_path
    #     )

    self.tokenizer = LlamaTokenizer.from_pretrained(self.model_path, token=self.hf_token)
    self.image_processor = MplugOwlImageProcessor.from_pretrained(self.model_path)
    self.processor = MplugOwlProcessor(self.image_processor, self.tokenizer)
    self.model = MplugOwlForConditionalGeneration.from_pretrained(
        self.model_path,
        torch_dtype=torch.bfloat16,
    ).to('cuda')
    self.model.eval()

compute_metrics(results)

This function is used to compute the metrics.

Parameters:

Name Type Description Default
results list

A list of results

required
Source code in aigve/metrics/multi_aspect_metrics/videophy/videophy_metric.py
def compute_metrics(self, results: list) -> dict:
    """
    This function is used to compute the metrics.

    Args:
        results (list): A list of results
    """
    return {
        'entailment': float(np.mean(results))
    }

get_entail(logits, input_ids)

This function is used to get the entailment scores.

Parameters:

Name Type Description Default
logits Tensor

A tensor containing the logits

required
input_ids Tensor

A tensor containing the input IDs

required
Source code in aigve/metrics/multi_aspect_metrics/videophy/videophy_metric.py
def get_entail(self, logits, input_ids):
    """
    This function is used to get the entailment scores.

    Args:
        logits (torch.Tensor): A tensor containing the logits
        input_ids (torch.Tensor): A tensor containing the input IDs
    """
    softmax = nn.Softmax(dim=2)
    logits = softmax(logits)
    token_id_yes = self.tokenizer.encode('Yes', add_special_tokens=False)[0]
    token_id_no = self.tokenizer.encode('No', add_special_tokens=False)[0]
    entailment = []
    for j in range(len(logits)):
        for i in range(len(input_ids[j])):
            if input_ids[j][i] == self.tokenizer.pad_token_id:  # pad token if the answer is not present
                i = i - 1
                break
            elif i == len(input_ids[j]) - 1:
                break
        score = logits[j][i][token_id_yes] / (logits[j][i][token_id_yes] + logits[j][i][token_id_no])
        entailment.append(score)
    entailment = torch.stack(entailment)
    return entailment
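
As a numeric illustration of the yes/no normalization in get_entail, the sketch below reproduces the core computation on toy logits. The vocabulary size, token ids, and answer position are made up for illustration; only the softmax-then-ratio step mirrors the method above.

import torch
import torch.nn as nn

# Toy setup: batch of 1, sequence length 3, vocabulary of 5 tokens.
# Assume token id 2 is "Yes", token id 3 is "No", and the answer sits at position 1.
logits = torch.tensor([[[0.1, 0.2, 2.0, 0.5, 0.1],
                        [0.1, 0.1, 3.0, 1.0, 0.2],
                        [0.0, 0.0, 0.0, 0.0, 0.0]]])
probs = nn.Softmax(dim=2)(logits)
token_id_yes, token_id_no, answer_pos = 2, 3, 1

p_yes = probs[0, answer_pos, token_id_yes]
p_no = probs[0, answer_pos, token_id_no]
entailment = p_yes / (p_yes + p_no)  # in (0, 1); higher means "Yes" is more likely
print(float(entailment))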

get_logits(data_batch)

This function is used to get the logits for each input in the data batch.

Parameters:

Name Type Description Default
data_batch dict

A dictionary containing the data batch

required

Returns: logits (torch.Tensor): A tensor containing the logits for each input in the data batch

Source code in aigve/metrics/multi_aspect_metrics/videophy/videophy_metric.py
def get_logits(self, data_batch):
    """
    This function is used to get the logits for each input in the data batch.

    Args:
        data_batch (dict): A dictionary containing the data batch
    Returns:
        logits (torch.Tensor): A tensor containing the logits for each input in the data batch
    """
    # Iterate over each item in the data batch
    for k, v in data_batch.items():
        # Check if the item is a tensor
        if torch.is_tensor(v):
            # Convert float tensors to bfloat16
            if v.dtype == torch.float:
                data_batch[k] = v.bfloat16()
            # Move the tensor to the model's device (e.g., GPU)
            data_batch[k] = data_batch[k].to(self.model.device)

    # print("Data batch: ", data_batch.keys())
    outputs = self.model(pixel_values=data_batch['pixel_values'], video_pixel_values=data_batch['video_pixel_values'],
                    labels=None, \
                    num_images=data_batch['num_images'], num_videos=data_batch['num_videos'], input_ids=data_batch['input_ids'],
                    non_padding_mask=data_batch['non_padding_mask'], \
                    non_media_mask=data_batch['non_media_mask'], prompt_mask=data_batch['prompt_mask'])
    logits = outputs['logits']
    return logits

process(data_batch, data_samples)

This function is used to process the data batch and compute the metric.

Parameters:

Name Type Description Default
data_batch dict

A dictionary containing the data batch

required
data_samples list

A list of dictionaries containing the data samples

required
Source code in aigve/metrics/multi_aspect_metrics/videophy/videophy_metric.py
def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
    """
    This function is used to process the data batch and compute the metric.

    Args:
        data_batch (dict): A dictionary containing the data batch
        data_samples (list): A list of dictionaries containing the data samples
    """
    logits = self.get_logits(data_batch)
    entails_scores =  self.get_entail(logits, data_batch['input_ids'])

    self.results.extend(entails_scores.cpu().detach().to(torch.float32).numpy().tolist())

VideoPhyDataset

Bases: Dataset

Source code in aigve/datasets/videophy_dataset.py
@DATASETS.register_module()
class VideoPhyDataset(Dataset):
    def __init__(self, data_path, video_root_path, hf_token, tokenizer=None, processor=None, max_length=2048, media_tokens=['<image>', '<|video|>'], hf_checkpoint='videophysics/videocon_physics'):
        """
        Args:
            data_path (str): Path to the annotation file (a JSON file)
            video_root_path (str): Root directory containing the videos
            hf_token (str): The Hugging Face access token used to download the tokenizer and processor
            tokenizer (Tokenizer): Tokenizer object
            processor (Processor): Processor object
            max_length (int): Maximum length of the input sequence
            media_tokens (list): List of media tokens
            hf_checkpoint (str): Hugging Face checkpoint for the tokenizer and processor
        """
        self.dataset = json.load(open(data_path))
        self.video_root_path = video_root_path

        self.hf_token = hf_token
        self.hf_checkpoint = hf_checkpoint
        self.max_length = max_length
        self.media_tokens = {k: -int(i + 1) for i, k in enumerate(media_tokens)}
        self.media_lengths = {'<image>': 1 + 64, '<|video|>': 1 + 64}
        self.bucket = {}


        # initialize tokenizer
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            self.tokenizer = LlamaTokenizer.from_pretrained(self.hf_checkpoint, token=self.hf_token)

        MplugOwlImageProcessor, MplugOwlProcessor = lazy_import_mplug_owl()
        self.image_processor = MplugOwlImageProcessor.from_pretrained(self.hf_checkpoint)
        # initialize processor
        if processor is not None:
            self.processor = processor
        else:
            self.processor = MplugOwlProcessor(self.image_processor, self.tokenizer)

    def __len__(self) -> int:
        """
        Returns:
            int: Length of the dataset
        """
        return self.dataset['metainfo']['length']

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the dataset
        Returns:
            dict: Dictionary containing the video, text, video path and caption
        """
        data = self.dataset['dataset_list'][idx]
        videopath = os.path.join(self.video_root_path, data['video_path_pd'])
        caption = data['prompt_gt']
        # video_input = self.processor(videos=[videopath], num_frames=16, return_tensors='pt') # video_pixel_values
        video_input = self.processor(videos=[videopath], num_frames=32, return_tensors='pt')  # video_pixel_values
        text_input = self._extract_text_token_from_conversation(caption, self.max_length, idx)
        item = {'video': video_input, 'text': text_input, 'videopath': videopath, 'caption': caption}
        return item

    def _extract_text_token_from_conversation(self, data, max_length, index):
        """
        Extracts the text tokens from the conversation
        Args:
            data (str): Conversation
            max_length (int): Maximum length of the input sequence
            index (int): Index of the dataset
        """
        # output enc_chunk
        enc_chunk = []

        if self.tokenizer.bos_token_id > 0:
            prompt_chunk = [self.tokenizer.bos_token_id]
        else:
            prompt_chunk = []

        # conversation = data["completion"]
        conversation = data

        # For Text only data
        if all([media_token not in conversation for media_token in self.media_tokens.keys()]):
            pattern = '|'.join(map(re.escape, ['AI: ', '\nHuman: ']))
            chunk_strs = re.split(f'({pattern})', conversation)
            prompt_length = -1
            stop_flag = False
            for idx, chunk_str in enumerate(chunk_strs):
                if idx == 0:
                    enc_chunk = prompt_chunk + \
                                self.tokenizer(chunk_str, add_special_tokens=False)[
                                    'input_ids']
                    enc_length = len(enc_chunk)
                    label_chunk = [0] * enc_length
                else:
                    if chunk_strs[idx - 1] == 'AI: ':
                        curr_chunk = self.tokenizer(
                            chunk_str, add_special_tokens=False)['input_ids']
                        if enc_length + len(curr_chunk) >= max_length:
                            curr_chunk = curr_chunk[:max_length - enc_length]
                            stop_flag = True
                        curr_chunk += [self.tokenizer.eos_token_id]
                        enc_length += len(curr_chunk)
                        enc_chunk += curr_chunk
                        label_chunk += [1] * len(curr_chunk)
                    else:
                        curr_chunk = self.tokenizer(
                            chunk_str, add_special_tokens=False)['input_ids']
                        if enc_length + len(curr_chunk) >= max_length + 1:
                            curr_chunk = curr_chunk[:max_length + 1 - enc_length]
                            stop_flag = True
                        enc_length += len(curr_chunk)
                        enc_chunk += curr_chunk
                        label_chunk += [0] * len(curr_chunk)
                    if stop_flag:
                        break

        # For Image-Text Data
        else:
            enc_length = 0
            prompt_length = -2
            pattern = '|'.join(
                map(re.escape, list(self.media_tokens.keys()) + ['AI: ', '\nHuman: ']))
            chunk_strs = re.split(f'({pattern})', conversation)
            chunk_strs = [x for x in chunk_strs if len(x) > 0]
            for idx, chunk_str in enumerate(chunk_strs):
                if enc_length >= max_length + 1:
                    break

                if idx == 0:
                    enc_chunk = prompt_chunk + \
                                self.tokenizer(chunk_str, add_special_tokens=False)[
                                    'input_ids']
                    enc_length = len(enc_chunk)
                    label_chunk = [0] * enc_length
                else:
                    if chunk_str in self.media_tokens:
                        # [CLS] + 256 + [EOS]
                        if enc_length + self.media_lengths[chunk_str] > max_length + 1:
                            break
                        else:
                            enc_chunk += [self.media_tokens[chunk_str]
                                          ] * self.media_lengths[chunk_str]
                            enc_length += self.media_lengths[chunk_str]
                            label_chunk += [0] * self.media_lengths[chunk_str]
                    else:

                        if chunk_strs[idx - 1] == 'AI: ':
                            curr_chunk = self.tokenizer(
                                chunk_str, add_special_tokens=False)['input_ids']
                            if enc_length + len(curr_chunk) >= max_length:
                                curr_chunk = curr_chunk[:max_length - enc_length]
                            curr_chunk += [self.tokenizer.eos_token_id]
                            enc_length += len(curr_chunk)
                            enc_chunk += curr_chunk
                            label_chunk += [1] * len(curr_chunk)
                        else:
                            curr_chunk = self.tokenizer(
                                chunk_str, add_special_tokens=False)['input_ids']
                            if enc_length + len(curr_chunk) >= max_length + 1:
                                curr_chunk = curr_chunk[:max_length +
                                                         1 - enc_length]
                            enc_length += len(curr_chunk)
                            enc_chunk += curr_chunk
                            label_chunk += [0] * len(curr_chunk)

        if enc_length < max_length + 1:
            padding_chunk = [self.tokenizer.pad_token_id] * \
                            (max_length + 1 - enc_length)
            padding_length = len(padding_chunk)
            label_chunk += [0] * (max_length + 1 - enc_length)
            enc_chunk = enc_chunk + padding_chunk
        else:
            padding_length = 0

        assert enc_length + padding_length == max_length + \
               1, (index, prompt_length, enc_length,
                   padding_length, max_length + 1)
        assert len(label_chunk) == max_length + \
               1, (len(label_chunk), max_length + 1)
        non_padding_mask = [1 if i < enc_length -
                                 1 else 0 for i in range(max_length)]

        enc_chunk = torch.tensor(enc_chunk).long()
        non_padding_mask = torch.tensor(non_padding_mask).long()
        prompt_mask = torch.tensor(label_chunk)[1:].long()
        prompt_length = torch.tensor([prompt_length]).long()

        # Create loss mask
        if all([media_token not in conversation for media_token in self.media_tokens.keys()]):
            non_media_mask = torch.ones_like(non_padding_mask).long()
        else:
            tmp_enc_chunk = enc_chunk.clone()
            tmp_enc_chunk[tmp_enc_chunk >= 0] = 1
            tmp_enc_chunk[tmp_enc_chunk < 0] = 0
            non_media_mask = torch.tensor(tmp_enc_chunk).long()
            non_media_mask = non_media_mask[1:].long()
        return {'input_ids': enc_chunk, "prompt_length": prompt_length, 'seq_length': enc_length,
                "non_padding_mask": non_padding_mask, 'non_media_mask': non_media_mask, 'prompt_mask': prompt_mask}

__getitem__(idx)

Parameters:

Name Type Description Default
idx int

Index of the dataset

required

Returns: dict: Dictionary containing the video, text, video path and caption

Source code in aigve/datasets/videophy_dataset.py
def __getitem__(self, idx):
    """
    Args:
        idx (int): Index of the dataset
    Returns:
        dict: Dictionary containing the video, text, video path and caption
    """
    data = self.dataset['dataset_list'][idx]
    videopath = os.path.join(self.video_root_path, data['video_path_pd'])
    caption = data['prompt_gt']
    # video_input = self.processor(videos=[videopath], num_frames=16, return_tensors='pt') # video_pixel_values
    video_input = self.processor(videos=[videopath], num_frames=32, return_tensors='pt')  # video_pixel_values
    text_input = self._extract_text_token_from_conversation(caption, self.max_length, idx)
    item = {'video': video_input, 'text': text_input, 'videopath': videopath, 'caption': caption}
    return item

__init__(data_path, video_root_path, hf_token, tokenizer=None, processor=None, max_length=2048, media_tokens=['<image>', '<|video|>'], hf_checkpoint='videophysics/videocon_physics')

Parameters:

Name Type Description Default
data_path str

Path to the annotation file (a JSON file)

required
video_root_path str

Root directory containing the videos

required
hf_token str

The Hugging Face access token used to download the tokenizer and processor

required
tokenizer Tokenizer

Tokenizer object

None
processor Processor

Processor object

None
max_length int

Maximum length of the input sequence

2048
media_tokens list

List of media tokens

['<image>', '<|video|>']
hf_checkpoint str

Hugging Face checkpoint for the tokenizer and processor

'videophysics/videocon_physics'
Source code in aigve/datasets/videophy_dataset.py
def __init__(self, data_path, video_root_path, hf_token, tokenizer=None, processor=None, max_length=2048, media_tokens=['<image>', '<|video|>'], hf_checkpoint='videophysics/videocon_physics'):
    """
    Args:
        data_path (str): Path to the annotation file (a JSON file)
        video_root_path (str): Root directory containing the videos
        hf_token (str): The Hugging Face access token used to download the tokenizer and processor
        tokenizer (Tokenizer): Tokenizer object
        processor (Processor): Processor object
        max_length (int): Maximum length of the input sequence
        media_tokens (list): List of media tokens
        hf_checkpoint (str): Hugging Face checkpoint for the tokenizer and processor
    """
    self.dataset = json.load(open(data_path))
    self.video_root_path = video_root_path

    self.hf_token = hf_token
    self.hf_checkpoint = hf_checkpoint
    self.max_length = max_length
    self.media_tokens = {k: -int(i + 1) for i, k in enumerate(media_tokens)}
    self.media_lengths = {'<image>': 1 + 64, '<|video|>': 1 + 64}
    self.bucket = {}


    # initialize tokenizer
    if tokenizer is not None:
        self.tokenizer = tokenizer
    else:
        self.tokenizer = LlamaTokenizer.from_pretrained(self.hf_checkpoint, token=self.hf_token)

    MplugOwlImageProcessor, MplugOwlProcessor = lazy_import_mplug_owl()
    self.image_processor = MplugOwlImageProcessor.from_pretrained(self.hf_checkpoint)
    # initialize processor
    if processor is not None:
        self.processor = processor
    else:
        self.processor = MplugOwlProcessor(self.image_processor, self.tokenizer)

__len__()

Returns:

Name Type Description
int int

Length of the dataset

Source code in aigve/datasets/videophy_dataset.py
def __len__(self) -> int:
    """
    Returns:
        int: Length of the dataset
    """
    return self.dataset['metainfo']['length']
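
Putting __len__ and __getitem__ together, the annotation JSON is expected to expose a metainfo.length field and a dataset_list of entries carrying video_path_pd and prompt_gt. The snippet below writes a minimal illustrative file; the paths and prompts are placeholders.

import json

# Illustrative annotation file matching the fields accessed above;
# the video paths and prompt texts are placeholders.
annotations = {
    "metainfo": {"length": 2},
    "dataset_list": [
        {"video_path_pd": "gen/video_0001.mp4", "prompt_gt": "A ball bounces on a wooden floor."},
        {"video_path_pd": "gen/video_0002.mp4", "prompt_gt": "Water pours from a glass onto a table."},
    ],
}
with open("videophy_annotations.json", "w") as f:
    json.dump(annotations, f, indent=2)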

VideoScore

Bases: BaseMetric

Source code in aigve/metrics/multi_aspect_metrics/videoscore/videoscore_metric.py
@METRICS.register_module()
class VideoScore(BaseMetric):
    def __init__(self,
                collect_device: Optional[Union[str, torch.device]] = None,
                prefix: Optional[str] = None,
                metric_path: str = None,
                model_path: str = 'TIGER-Lab/VideoScore-v1.1',
                datainfo_path: str = None,
                test_index: int = None,
                 **kwargs):
        """
        Args:
            collect_device (Optional[Union[str, torch.device]]): The device to collect the data on.
            prefix (Optional[str]): The prefix to use for the metric.
            metric_path (str): The path to the metric file.
            model_path (str): The path to the model file.
            datainfo_path (str): The path to the datainfo file.
            test_index (int): The index of the test data.
        """
        super().__init__(collect_device=collect_device, prefix=prefix)
        # self.train_index = train_index
        # TODO: ARE THERE PARAMETERS REQUIRED FOR THIS METRIC?
        self.metric_path = metric_path
        self.model_path = model_path
        self.datainfo_path = datainfo_path
        self.test_index = test_index


        self.model = Idefics2ForSequenceClassification.from_pretrained(self.model_path, torch_dtype=torch.bfloat16).eval()
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model.to(self.device)

        self.results = []

    def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
        """
        Args:
            data_batch (Any): The data batch to process.
            data_samples (Sequence[dict]): The data samples to process.
        """


        data_batch = {k: v[0].to(self.model.device) for k, v in data_batch.items()}

        with torch.no_grad():
            outputs = self.model(**data_batch)

        logits = outputs.logits.cpu().detach().to(torch.float32).numpy()
        num_aspects = logits.shape[-1]

        aspect_scores = []
        for i in range(num_aspects):
            aspect_scores.append(round(logits[0, i].item(), 3))

        self.results.append(aspect_scores)

    def compute_metrics(self, results: list) -> dict:
        """
        Args:
            results (list): The results to compute the metrics from.
        """
        results = np.array(results)
        mean_scores = np.mean(results, axis=0)  # per-aspect mean across all processed videos

        return {'visual_quality': results[:, 0].tolist(),
                'temporal_consistency': results[:, 1].tolist(),
                'dynamic_degree': results[:, 2].tolist(),
                'text-to-video_alignment': results[:, 3].tolist(),
                'factual_consistency': results[:, 4].tolist(),
                'summary': {'visual_quality': mean_scores[0], 'temporal_consistency': mean_scores[1],
                            'dynamic_degree': mean_scores[2], 'text-to-video_alignment': mean_scores[3],
                            'factual_consistency': mean_scores[4]}}

__init__(collect_device=None, prefix=None, metric_path=None, model_path='TIGER-Lab/VideoScore-v1.1', datainfo_path=None, test_index=None, **kwargs)

Parameters:

Name Type Description Default
collect_device Optional[Union[str, device]]

The device to collect the data on.

None
prefix Optional[str]

The prefix to use for the metric.

None
metric_path str

The path to the metric file.

None
model_path str

The path to the model file.

'TIGER-Lab/VideoScore-v1.1'
datainfo_path str

The path to the datainfo file.

None
test_index int

The index of the test data.

None
Source code in aigve/metrics/multi_aspect_metrics/videoscore/videoscore_metric.py
def __init__(self,
            collect_device: Optional[Union[str, torch.device]] = None,
            prefix: Optional[str] = None,
            metric_path: str = None,
            model_path: str = 'TIGER-Lab/VideoScore-v1.1',
            datainfo_path: str = None,
            test_index: int = None,
             **kwargs):
    """
    Args:
        collect_device (Optional[Union[str, torch.device]]): The device to collect the data on.
        prefix (Optional[str]): The prefix to use for the metric.
        metric_path (str): The path to the metric file.
        model_path (str): The path to the model file.
        datainfo_path (str): The path to the datainfo file.
        test_index (int): The index of the test data.
    """
    super().__init__(collect_device=collect_device, prefix=prefix)
    # self.train_index = train_index
    # TODO: ARE THERE PARAMETERS REQUIRED FOR THIS METRIC?
    self.metric_path = metric_path
    self.model_path = model_path
    self.datainfo_path = datainfo_path
    self.test_index = test_index


    self.model = Idefics2ForSequenceClassification.from_pretrained(self.model_path, torch_dtype=torch.bfloat16).eval()
    self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    self.model.to(self.device)

    self.results = []

compute_metrics(results)

Parameters:

Name Type Description Default
results list

The results to compute the metrics from.

required
Source code in aigve/metrics/multi_aspect_metrics/videoscore/videoscore_metric.py
def compute_metrics(self, results: list) -> dict:
    """
    Args:
        results (list): The results to compute the metrics from.
    """
    results = np.array(results)
    mean_scores = np.mean(results, axis=0)  # per-aspect mean across all processed videos

    return {'visual_quality': results[:, 0].tolist(),
            'temporal_consistency': results[:, 1].tolist(),
            'dynamic_degree': results[:, 2].tolist(),
            'text-to-video_alignment': results[:, 3].tolist(),
            'factual_consistency': results[:, 4].tolist(),
            'summary': {'visual_quality': mean_scores[0], 'temporal_consistency': mean_scores[1],
                        'dynamic_degree': mean_scores[2], 'text-to-video_alignment': mean_scores[3],
                        'factual_consistency': mean_scores[4]}}

process(data_batch, data_samples)

Parameters:

Name Type Description Default
data_batch Any

The data batch to process.

required
data_samples Sequence[dict]

The data samples to process.

required
Source code in aigve/metrics/multi_aspect_metrics/videoscore/videoscore_metric.py
def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
    """
    Args:
        data_batch (Any): The data batch to process.
        data_samples (Sequence[dict]): The data samples to process.
    """


    data_batch = {k: v[0].to(self.model.device) for k, v in data_batch.items()}

    with torch.no_grad():
        outputs = self.model(**data_batch)

    logits = outputs.logits.cpu().detach().to(torch.float32).numpy()
    num_aspects = logits.shape[-1]

    aspect_scores = []
    for i in range(num_aspects):
        aspect_scores.append(round(logits[0, i].item(), 3))

    self.results.append(aspect_scores)
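
To make the aggregation concrete, the sketch below reproduces the compute_metrics output shape for two videos with made-up aspect scores; the aspect ordering follows the dictionary keys above.

import numpy as np

# Two videos with five made-up aspect scores each, in the order:
# visual_quality, temporal_consistency, dynamic_degree,
# text-to-video_alignment, factual_consistency.
aspects = ['visual_quality', 'temporal_consistency', 'dynamic_degree',
           'text-to-video_alignment', 'factual_consistency']
results = np.array([[3.1, 2.8, 2.5, 3.0, 2.9],
                    [2.7, 2.6, 2.9, 3.2, 3.0]])

per_video = {name: results[:, i].tolist() for i, name in enumerate(aspects)}
summary = {name: float(results[:, i].mean()) for i, name in enumerate(aspects)}
print(per_video)   # per-video score lists, one per aspect
print(summary)     # per-aspect means across the two videos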

Organization of this Library