
aigve.datasets

CLIPTempDataset

Bases: Dataset

Source code in aigve/datasets/cliptemp_dataset.py
@DATASETS.register_module()
class CLIPTempDataset(Dataset):
    def __init__(self, processor_name, prompt_dir, video_dir):
        super(CLIPTempDataset, self).__init__()
        self.prompt_dir = prompt_dir
        self.video_dir = video_dir
        self.processor_name = processor_name

        self.processor = AutoProcessor.from_pretrained(self.processor_name)
        self.video_names = self._read_videoname()

    def _read_videoname(self):
        with open(self.prompt_dir, 'r') as reader:
            read_data = json.load(reader)

        video_name_list = []
        for item in read_data["data_list"]:
            video_name = item['video_path_pd'].strip()
            video_name_list.append(video_name)

        return video_name_list

    def __len__(self):
        return len(self.video_names)-1

    def __getitem__(self, index):
        '''Return the processed frames of a video; consecutive frames form the frame pairs used for temporal scoring.
        '''
        video_name = self.video_names[index]
        video_path = self.video_dir + video_name
        frames = []
        cap = cv2.VideoCapture(video_path)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            resized_frame = cv2.resize(frame,(224,224))  # Resize the frame to match the expected input size
            frames.append(resized_frame)

        input_frame_tensor = self.processor(
            images=frames,
            padding=True,
            truncation=True,
            max_length=77,
            return_tensors="pt",
        )['pixel_values']

        return input_frame_tensor
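
A minimal usage sketch. The import path, processor name, and file paths below are assumptions for illustration, not part of the dataset; batch_size=1 is used because the number of frames varies per video.

from torch.utils.data import DataLoader
from aigve.datasets import CLIPTempDataset  # assumed import path

# Hypothetical paths and CLIP processor; replace with your own.
dataset = CLIPTempDataset(
    processor_name='openai/clip-vit-base-patch32',
    prompt_dir='annotations/evaluate.json',
    video_dir='videos/',
)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

for frames in loader:
    # frames: [1, T, 3, 224, 224] pixel values ready for a CLIP image encoder
    print(frames.shape)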

__getitem__(index)

Return the processed frames of a video; consecutive frames form the frame pairs used for temporal scoring.

Source code in aigve/datasets/cliptemp_dataset.py
def __getitem__(self, index):
    '''Return the processed frames of a video; consecutive frames form the frame pairs used for temporal scoring.
    '''
    video_name = self.video_names[index]
    video_path = self.video_dir + video_name
    frames = []
    cap = cv2.VideoCapture(video_path)
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        resized_frame = cv2.resize(frame,(224,224))  # Resize the frame to match the expected input size
        frames.append(resized_frame)

    input_frame_tensor = self.processor(
        images=frames,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    )['pixel_values']

    return input_frame_tensor
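
To illustrate how the returned tensor can be consumed (this is not the metric's actual implementation), the following sketch scores consecutive-frame similarity with a CLIP image encoder from transformers; the checkpoint name is an assumption.

import torch
from transformers import CLIPModel

model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')  # assumed checkpoint
model.eval()

def consecutive_frame_similarity(pixel_values: torch.Tensor) -> torch.Tensor:
    """pixel_values: [T, 3, 224, 224], as returned by CLIPTempDataset.__getitem__."""
    with torch.no_grad():
        feats = model.get_image_features(pixel_values=pixel_values)  # [T, D]
    feats = torch.nn.functional.normalize(feats, dim=-1)
    # Cosine similarity between each frame and the next one: [T-1]
    return (feats[:-1] * feats[1:]).sum(dim=-1)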

FidDataset

Bases: Dataset

Dataset for Fréchet Inception Distance (FID) evaluation.

For each sample, this dataset:

- Loads both the ground-truth (real) and generated (predicted) videos.
- Converts each video into a tensor of shape [T, C, H, W] using OpenCV.
- Optionally pads or truncates videos to a fixed number of frames.

Parameters:

Name Type Description Default
video_dir str

Directory containing video files.

required
prompt_dir str

Path to JSON file that lists ground-truth and generated video filenames.

required
max_len int

Maximum number of frames to load per video. Default: 500.

500
if_pad bool

Whether to pad videos to exactly max_len frames. If False, videos can have variable lengths.

False
Source code in aigve/datasets/fid_dataset.py
@DATASETS.register_module()
class FidDataset(Dataset):
    """
    Dataset for Fréchet Inception Distance (FID) evaluation.

    For each sample, this dataset:
        - Loads both the ground-truth (real) and generated (predicted) videos.
        - Converts each video into a tensor of shape [T, C, H, W] using OpenCV.
        - Optionally pads or truncates videos to a fixed number of frames.

    Args:
        video_dir (str): Directory containing video files.
        prompt_dir (str): Path to JSON file that lists ground-truth and generated video filenames.
        max_len (int): Maximum number of frames to load per video. Default: 500.
        if_pad (bool): Whether to pad videos to exactly `max_len` frames. If False, videos can have variable lengths.
    """

    def __init__(self, video_dir, prompt_dir, max_len=500, if_pad=False):
        super(FidDataset, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.video_dir = video_dir
        self.prompt_dir = prompt_dir
        self.max_len = max_len
        self.if_pad = if_pad

        self.gt_video_names, self.gen_video_names = self._read_video_names()

    def _read_video_names(self):
        """Reads video names from the dataset JSON file."""
        with open(self.prompt_dir, 'r') as reader:
            read_data = json.load(reader)
            gt_video_names = [item['video_path_gt'].strip() for item in read_data["data_list"]]
            gen_video_names = [item['video_path_pd'].strip() for item in read_data["data_list"]]
        return gt_video_names, gen_video_names

    def _load_video_tensor(self, video_path: str) -> torch.Tensor:
        """Load a video and return its tensor of shape [T, C, H, W]."""
        assert os.path.exists(video_path), f"Video file not found: {video_path}"
        cap = cv2.VideoCapture(video_path)
        input_frames = []
        frame_count = 0
        while cap.isOpened() and frame_count < self.max_len:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            input_frames.append(torch.tensor(frame).float())
            frame_count += 1

        cap.release()

        if len(input_frames) == 0:
            raise RuntimeError(f"No valid frames found in {video_path}")

        if self.if_pad:
            num_frames = len(input_frames)
            if num_frames < self.max_len:
                pad_frames = torch.zeros((self.max_len - num_frames, *input_frames[0].shape))
                video_tensor = torch.cat((torch.stack(input_frames), pad_frames), dim=0)
            else:
                video_tensor = torch.stack(input_frames[:self.max_len])
        else:
            video_tensor = torch.stack(input_frames)

        # Convert from [T, H, W, C] to [T, C, H, W]
        return video_tensor.permute(0, 3, 1, 2)

    def __len__(self):
        return len(self.gt_video_names)

    def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor, str, str]:
        """
        Returns:
            Tuple[torch.Tensor, torch.Tensor, str, str]: 
                - Ground-truth (Real) video tensor of shape [T, C, H, W].
                - Generated video tensor of shape [T, C, H, W].
                - Ground-truth video name.
                - Generated video name.
        """
        gt_video_name = self.gt_video_names[index]
        gt_video_path = os.path.join(self.video_dir, gt_video_name)
        gen_video_name = self.gen_video_names[index]
        gen_video_path = os.path.join(self.video_dir, gen_video_name) 

        gt_video_tensor = self._load_video_tensor(gt_video_path)
        gen_video_tensor = self._load_video_tensor(gen_video_path)

        return gt_video_tensor, gen_video_tensor, gt_video_name, gen_video_name
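
A minimal sketch of iterating the dataset; the import path and file paths are placeholders. With `if_pad=False`, videos can differ in length, so batch_size=1 is used.

from torch.utils.data import DataLoader
from aigve.datasets import FidDataset  # assumed import path

dataset = FidDataset(
    video_dir='videos/',
    prompt_dir='annotations/evaluate.json',
    max_len=500,
    if_pad=False,
)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

for gt_video, gen_video, gt_name, gen_name in loader:
    # gt_video, gen_video: [1, T, C, H, W]; T may differ between samples when if_pad=False
    print(gt_name[0], gt_video.shape, gen_name[0], gen_video.shape)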

__getitem__(index)

Returns:

Type Description
tuple[Tensor, Tensor, str, str]

Tuple[torch.Tensor, torch.Tensor, str, str]:

- Ground-truth (Real) video tensor of shape [T, C, H, W].
- Generated video tensor of shape [T, C, H, W].
- Ground-truth video name.
- Generated video name.

Source code in aigve/datasets/fid_dataset.py
def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor, str, str]:
    """
    Returns:
        Tuple[torch.Tensor, torch.Tensor, str, str]: 
            - Ground-truth (Real) video tensor of shape [T, C, H, W].
            - Generated video tensor of shape [T, C, H, W].
            - Ground-truth video name.
            - Generated video name.
    """
    gt_video_name = self.gt_video_names[index]
    gt_video_path = os.path.join(self.video_dir, gt_video_name)
    gen_video_name = self.gen_video_names[index]
    gen_video_path = os.path.join(self.video_dir, gen_video_name) 

    gt_video_tensor = self._load_video_tensor(gt_video_path)
    gen_video_tensor = self._load_video_tensor(gen_video_path)

    return gt_video_tensor, gen_video_tensor, gt_video_name, gen_video_name

GSTVQADataset

Bases: Dataset

Dataset for GSTVQA metric, supports feature extraction using VGG16 or ResNet.

Source code in aigve/datasets/gstvqa_dataset.py
@DATASETS.register_module()
class GSTVQADataset(Dataset):
    """Dataset for GSTVQA metric, supports feature extraction using VGG16 or ResNet."""

    def __init__(self, video_dir, prompt_dir, model_name='vgg16', max_len=500):
        super(GSTVQADataset, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.video_dir = video_dir
        self.prompt_dir = prompt_dir
        self.model_name = model_name
        self.max_len = max_len
        self.feature_extractor = FeatureExtractor(model_name=model_name)

        self.prompts, self.video_names = self._read_prompt_videoname()

    def _read_prompt_videoname(self):
        with open(self.prompt_dir, 'r') as reader:
            read_data = json.load(reader)

        prompt_data_list, video_name_list = [], []
        for item in read_data["data_list"]:
            prompt = item['prompt_gt'].strip()
            video_name = item['video_path_pd'].strip()
            prompt_data_list.append(prompt)
            video_name_list.append(video_name)

        return prompt_data_list, video_name_list

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, index) -> tuple[torch.Tensor, int, str]:
        """
        Returns a tuple of:
            deep_features (torch.Tensor): Shape [max_len, 2944]
                Mean and std features extracted from input frames using the chosen model (VGG16 or ResNet).
                Padded to self.max_len if the number of frames is less.
            num_frames (int): The number of frames in the video.
            video_name (str): The file name for the video.
        """
        video_name = self.video_names[index]
        video_path = os.path.join(self.video_dir, video_name)
        input_frames = []

        cap = cv2.VideoCapture(video_path)
        frame_count = 0

        while cap.isOpened() and frame_count < self.max_len:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # frame = cv2.resize(frame, self.frame_size)
            input_frames.append(torch.tensor(frame).float())
            frame_count += 1

        cap.release()

        # Pad or truncate frames to max_len
        num_frames = len(input_frames)
        # print('num_frames: ', num_frames)
        if num_frames < 30:
            pad_frames = torch.zeros((30 - num_frames, *input_frames[0].shape))
            input_frames_tensor = torch.cat((torch.stack(input_frames), pad_frames), dim=0)
            num_frames = 30 # Force min frames to be 30 (since two att_frams=15(kernel_size) used in GSTVQA)
        elif num_frames < self.max_len:
            pad_frames = torch.zeros((self.max_len - num_frames, *input_frames[0].shape))
            input_frames_tensor = torch.cat((torch.stack(input_frames), pad_frames), dim=0)
        else:
            input_frames_tensor = torch.stack(input_frames[:self.max_len])
        # print('input_frames_tensor: ', input_frames_tensor.shape) # shape: toy data [max_len, H(512), W(512), C(3)]

        # Convert from [T, H, W, C] to [T, C, H, W]
        input_frames_tensor = input_frames_tensor.permute(0, 3, 1, 2) 

        # Extract features using the chosen model (VGG16 or ResNet)
        with torch.no_grad():
            mean_features, std_features = self.feature_extractor(input_frames_tensor) # Shape: [T, 1472]: [10, 1472]

        # Concatenate to match GSTVQA expected 2944-dim features
        deep_features = torch.cat((mean_features, std_features), dim=1)  # Shape: [T, 2944]

        # Ensure output shape [max_len, 2944] (pad if needed)
        if deep_features.shape[0] < self.max_len:
            pad_size = self.max_len - deep_features.shape[0]
            padding = torch.zeros((pad_size, 2944), device=deep_features.device)
            deep_features = torch.cat((deep_features, padding), dim=0)

        return deep_features, num_frames, video_name
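
A brief sketch of consuming one sample; the import path and file paths are placeholders, and the masking step only illustrates how `num_frames` can be used to drop zero-padded rows (it is not part of the dataset).

from aigve.datasets import GSTVQADataset  # assumed import path

dataset = GSTVQADataset(
    video_dir='videos/',
    prompt_dir='annotations/evaluate.json',
    model_name='vgg16',
    max_len=500,
)
deep_features, num_frames, video_name = dataset[0]
# deep_features: [max_len, 2944]; rows at index >= num_frames are zero padding
valid_features = deep_features[:num_frames]
print(video_name, deep_features.shape, valid_features.shape)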

__getitem__(index)

Returns a tuple of:

- deep_features (torch.Tensor): Shape [max_len, 2944]. Mean and std features extracted from input frames using the chosen model (VGG16 or ResNet). Padded to self.max_len if the number of frames is less.
- num_frames (int): The number of frames in the video.
- video_name (str): The file name for the video.

Source code in aigve/datasets/gstvqa_dataset.py
def __getitem__(self, index) -> tuple[torch.Tensor, int, str]:
    """
    Returns a tuple of:
        deep_features (torch.Tensor): Shape [max_len, 2944]
            Mean and std features extracted from input frames using the chosen model (VGG16 or ResNet).
            Padded to self.max_len if the number of frames is less.
        num_frames (int): The number of frames in the video.
        video_name (str): The file name for the video.
    """
    video_name = self.video_names[index]
    video_path = os.path.join(self.video_dir, video_name)
    input_frames = []

    cap = cv2.VideoCapture(video_path)
    frame_count = 0

    while cap.isOpened() and frame_count < self.max_len:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # frame = cv2.resize(frame, self.frame_size)
        input_frames.append(torch.tensor(frame).float())
        frame_count += 1

    cap.release()

    # Pad or truncate frames to max_len
    num_frames = len(input_frames)
    # print('num_frames: ', num_frames)
    if num_frames < 30:
        pad_frames = torch.zeros((30 - num_frames, *input_frames[0].shape))
        input_frames_tensor = torch.cat((torch.stack(input_frames), pad_frames), dim=0)
        num_frames = 30 # Force min frames to be 30 (since two att_frams=15(kernel_size) used in GSTVQA)
    elif num_frames < self.max_len:
        pad_frames = torch.zeros((self.max_len - num_frames, *input_frames[0].shape))
        input_frames_tensor = torch.cat((torch.stack(input_frames), pad_frames), dim=0)
    else:
        input_frames_tensor = torch.stack(input_frames[:self.max_len])
    # print('input_frames_tensor: ', input_frames_tensor.shape) # shape: toy data [max_len, H(512), W(512), C(3)]

    # Convert from [T, H, W, C] to [T, C, H, W]
    input_frames_tensor = input_frames_tensor.permute(0, 3, 1, 2) 

    # Extract features using the chosen model (VGG16 or ResNet)
    with torch.no_grad():
        mean_features, std_features = self.feature_extractor(input_frames_tensor) # Shape: [T, 1472]: [10, 1472]

    # Concatenate to match GSTVQA expected 2944-dim features
    deep_features = torch.cat((mean_features, std_features), dim=1)  # Shape: [T, 2944]

    # Ensure output shape [max_len, 2944] (pad if needed)
    if deep_features.shape[0] < self.max_len:
        pad_size = self.max_len - deep_features.shape[0]
        padding = torch.zeros((pad_size, 2944), device=deep_features.device)
        deep_features = torch.cat((deep_features, padding), dim=0)

    return deep_features, num_frames, video_name

LightVQAPlusDataset

Bases: Dataset

Dataset for LightVQA+. Extracts:

- spatial_features (torch.Tensor): Extracted key frames.
- temporal_features (torch.Tensor): SlowFast motion features.
- BNS_features (torch.Tensor): Brightness & Noise features.
- BC_features (torch.Tensor): Temporal CLIP-based brightness contrast features.
- video_name (str): Video filename.

Source code in aigve/datasets/lightvqa_plus_dataset.py
@DATASETS.register_module()
class LightVQAPlusDataset(Dataset):
    """
    Dataset for LightVQA+.
    Extracts:
        - spatial_features (torch.Tensor): Extracted key frames.
        - temporal_features (torch.Tensor): SlowFast motion features.
        - BNS_features (torch.Tensor): Brightness & Noise features.
        - BC_features (torch.Tensor): Temporal CLIP-based brightness contrast features.
        - video_name (str): Video filename.
    """

    def __init__(self, video_dir, prompt_dir, min_video_seconds=8):
        super(LightVQAPlusDataset, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.video_dir = video_dir
        self.prompt_dir = prompt_dir
        self.min_video_seconds = min_video_seconds

        self.video_names = self._read_video_names()

        # Load CLIP model for BNS and BC features
        self.clip_model, _ = clip.load("ViT-B/32", device="cpu")
        self.preprocess = transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
        )
        self.to_tensor = transforms.ToTensor()

        # CLIP text prompts
        self.text_B = clip.tokenize([  # brightness (B)
            "an underexposed photo", "a slightly underexposed photo",
            "a well-exposed photo", "a slightly overexposed photo", "an overexposed photo"
        ])

        self.text_N = clip.tokenize([ # noise (N)
            "a photo with no noise", "a photo with little noise",
            "a photo with considerable noise", "a photo with serious noise", "a photo with extreme noise"
        ])

        self.submodel_path = os.path.join(os.getcwd(), 'metrics/video_quality_assessment/nn_based/lightvqa_plus')
        if not submodule_exists(self.submodel_path):
            add_git_submodule(
                repo_url='https://github.com/SaMMyCHoo/Light-VQA-plus.git', 
                submodule_path=self.submodel_path
            )
        # original_path = os.path.join(self.submodel_path, "Light-VQA-plus")
        lightvqa_path = os.path.join(self.submodel_path, "Light_VQA_plus")
        # if os.path.exists(original_path) and not os.path.exists(lightvqa_path):
        #     os.rename(original_path, lightvqa_path)
        if lightvqa_path not in sys.path:
            sys.path.insert(0, lightvqa_path)
        # print(sys.path)

        # Load SlowFast model
        slowfast, _ = lazy_import()
        self.slowfast_model = slowfast()

    def _read_video_names(self):
        """Reads video names from the dataset JSON file."""
        with open(self.prompt_dir, 'r') as reader:
            read_data = json.load(reader)
        return [item['video_path_pd'].strip() for item in read_data["data_list"]]

    def __len__(self):
        return len(self.video_names)

    def extract_key_frames(self, video_path):
        """
        Extracts 8 evenly spaced key frames across the entire video duration.

        Args:
            video_path (str): Path to the video file.

        Returns:
            spatial_features (torch.Tensor): Shape [8, 3, 672, 1120] containing 8 key frames.
        """
        cap = cv2.VideoCapture(video_path)
        video_name = os.path.basename(video_path).split('.')[0]

        video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if video_length >= 8:
            # Select 8 unique frame indices evenly spaced across the entire video
            frame_indices = np.round(np.linspace(0, video_length - 1, num=8)).astype(int)
        else:
            # Select all available frames and repeat the last one to reach 8
            frame_indices = list(range(video_length)) + [video_length - 1] * (8 - video_length)

        spatial_features = torch.zeros([8, 3, 672, 1120])  # Ensure exactly 8 frames
        transform = transforms.Compose([
            transforms.Resize([672, 1120]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        last_valid_frame = None
        for idx, frame_idx in enumerate(frame_indices):
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if ret:
                frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                spatial_features[idx] = transform(frame)
                last_valid_frame = spatial_features[idx]
            elif last_valid_frame is not None:  # If total frames are less than 8, repeat the last valid frame
                spatial_features[idx] = last_valid_frame

        cap.release()
        # print('spatial_features: ', spatial_features.shape) # torch.Size([8, 3, 672, 1120])
        return spatial_features

    def get_global_sf(self, video_path) -> torch.Tensor:
        """Extracts global brightness & noise features across full video.

        Args:
            video_path (str): Path to video file.

        Returns:
            torch.Tensor: Extracted global features (Shape: [8, 150]).
        """
        cap = cv2.VideoCapture(video_path)
        video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # print('video_length: ', video_length)  # 16

        frames = []
        for _ in range(video_length):
            ret, frame = cap.read()
            if ret:
                frame = cv2.resize(frame, (1120, 672))
                frames.append(frame)
        cap.release()

        if not frames:
            raise ValueError(f"Failed to extract frames from {video_path}")

        res = []
        length = len(frames)
        now = 0
        interval = 10  # Process 10 frames at a time
        while now + interval - 1 < length:
            final = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i + now], cv2.COLOR_BGR2RGB)))
                    for i in range(interval)]

            # Step 1: Convert to tensor batch
            images = torch.stack(final, dim=0)  # Shape: [10, 3, 672, 1120]

            # Step 2: Unfold into patches (Strictly following GET_SF)
            images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [10, 3, 3, 5, 224, 224]
            images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [10, 5, 3, 3, 224, 224]
            images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [10, 15, 3, 224, 224]
            images = images.view(-1, 3, 224, 224)  # Shape: [150, 3, 224, 224]
            images = self.preprocess(images)  # Normalize for CLIP
            # print('images get_global_sf: ', images.shape) # torch.Size([10*15, 3, 224, 224])

            # Step 3: Extract features using CLIP
            with torch.no_grad():
                logits_N, _ = self.clip_model(images, self.text_N)
                logits_B, _ = self.clip_model(images, self.text_B)

            tmp_N = logits_N.softmax(dim=-1).view(interval, -1) * 10
            tmp_B = logits_B.softmax(dim=-1).view(interval, -1) * 10
            # print('tmp_N get_global_sf', tmp_N.shape) # torch.Size([10, 75])
            # print('tmp_B get_global_sf', tmp_B.shape) # torch.Size([10, 75])
            res.append(torch.cat([tmp_N, tmp_B], dim=1))
            now += interval

        # Handle remaining frames
        if length > now:
            final = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)))
                    for i in range(now, length)]

            images = torch.stack(final, dim=0)  # Shape: [remaining(6), 3, 672, 1120]
            images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [remaining, 3, 3, 5, 224, 224]
            images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [remaining, 5, 3, 3, 224, 224]
            images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [remaining, 15, 3, 224, 224]
            images = images.view(-1, 3, 224, 224)  # Shape: [remaining*15, 3, 224, 224]
            images = self.preprocess(images)

            with torch.no_grad():
                logits_N, _ = self.clip_model(images, self.text_N) # Shape: [remaining, 5(num_text_prompts)]
                logits_B, _ = self.clip_model(images, self.text_B) # Shape: [remaining, 5]
                # print('logits_N last get_global_sf', logits_N.shape) # torch.Size([6*15, 5])
                # print('logits_B last get_global_sf', logits_B.shape) #torch.Size([6*15, 5])

            tmp_N = logits_N.softmax(dim=-1).view(length - now, -1) * 10 # Shape: [remaining, 75]
            tmp_B = logits_B.softmax(dim=-1).view(length - now, -1) * 10 # Shape: [remaining, 75]
            # print('tmp_N last get_global_sf', tmp_N.shape)  # torch.Size([6, 75])
            # print('tmp_B last get_global_sf', tmp_B.shape)  # torch.Size([6, 75])

            res.append(torch.cat([tmp_N, tmp_B], dim=1))

        res = torch.cat(res, dim=0)  # Shape: [length, 150]
        # print('res: ', res.shape)  # torch.Size([16, 150]) for toy dataset

        # Step 4: Aggregate into 8 time slots
        chunk_size = length // 8
        final_res = [
            torch.mean(res[i * chunk_size: (i + 1) * chunk_size], dim=0) if i < 7 else torch.mean(res[7 * chunk_size:], dim=0)
            for i in range(8)
        ]

        return torch.stack(final_res, dim=0)  # Shape: [8, 150]

    def extract_bns_features(self, video_path):
        """Extracts Brightness & Noise Sensitivity (BNS) features using CLIP.
        Local Feature Extraction (res1) → Uses 8 key frames
        Global Feature Extraction (res2) → Uses all frames

        Args:
            video_path (str): Path to the video file.

        Returns:
            spatial_features (torch.Tensor): Extracted 8 evenly spaced key frames across the entire video duration.
                Shape [8, 3, 672, 1120] containing 8 key frames.
            final_res (torch.Tensor): Extracted BNS feature (Shape: [8, 300]).
        """
        # Local Feature Extraction Step 1: Extract key frames
        spatial_features = self.extract_key_frames(video_path) # Shape: [8, 3, 672, 1120]

        # Step 2: Apply unfolding transformation (Strictly following GET_S_F)
        images = spatial_features.unfold(2, 224, 224).unfold(3, 224, 224)  # Break into patches. Shape: [8, 3, 3, 5, 224, 224]
        images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [8, 5, 3, 3, 224, 224]
        images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [8, 15, 3, 224, 224]
        images = images.view(-1, 3, 224, 224)  # Shape: [120, 3, 224, 224]
        images = self.preprocess(images)  # Normalize for CLIP
        # print('images: ', images.shape) # torch.Size([120, 3, 224, 224])
        # print(images.device)
        # print(self.text_N.device)

        # Step 3: Pass through CLIP
        with torch.no_grad():
            logits_N, _ = self.clip_model(images, self.text_N)
            logits_B, _ = self.clip_model(images, self.text_B)

        res_N = logits_N.softmax(dim=-1).view(8, -1) * 10
        # print('res_N: ', res_N.shape) # torch.Size([8, 75])
        res_B = logits_B.softmax(dim=-1).view(8, -1) * 10
        # print('res_B: ', res_N.shape) # torch.Size([8, 75])
        res1 = torch.cat((res_N, res_B), dim=1)
        # print('res1: ', res1.shape) # torch.Size([8, 150])

        # Global Feature Extraction (GET_SF Equivalent)
        res2 = self.get_global_sf(video_path)
        # print('res2: ', res2.shape) # res2:  torch.Size([8, 150])

        # Split & Combine Features
        Nl, Bl = torch.split(res1, 75, dim=1)
        Ng, Bg = torch.split(res2, 75, dim=1)
        final_res = torch.cat([Nl, Ng, Bl, Bg], dim=1)
        # print('final_res: ', final_res.shape)

        return spatial_features, final_res  # Shape: [8, 300]

    def extract_bc_features(self, video_path) -> torch.Tensor:
        """
        Extracts Brightness Consistency features using CLIP-based temporal processing.

        Returns:
            torch.Tensor: Extracted BC feature (Shape: [8, final_dim]).
        """

        cap = cv2.VideoCapture(video_path)
        video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        frames = []
        for _ in range(video_length):
            ret, frame = cap.read()
            if ret:
                frame = cv2.resize(frame, (1120, 672))
                frames.append(frame)
        cap.release()

        if not frames:
            raise ValueError(f"Failed to extract frames from {video_path}")

        res = []
        now = 0
        interval = 10  # Process 10 frames at a time
        length = len(frames)

        # Step 1: Extract CLIP Features at Fixed Intervals
        while now + interval - 1 < length:
            batch = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i + now], cv2.COLOR_BGR2RGB)))
                    for i in range(interval)]
            images = torch.stack(batch, dim=0)
            images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [10, 3, 3, 5, 224, 224]
            images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [10, 5, 3, 3, 224, 224]
            images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [10, 15, 3, 224, 224]
            images = images.view(-1, 3, 224, 224)  # Shape: [10*15, 3, 224, 224]
            images = self.preprocess(images)
            # print('images extract_bc_features', images.shape) # torch.Size([150, 3, 224, 224])

            with torch.no_grad():
                logits, _ = self.clip_model(images, self.text_B)

            tmp = logits.softmax(dim=-1) * 10
            res.append(tmp)
            now += interval

        # Handle Remaining Frames
        if length > now:
            batch = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)))
                    for i in range(now, length)]
            images = torch.stack(batch, dim=0)
            images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [remaining(6), 3, 3, 5, 224, 224]
            images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [remaining, 5, 3, 3, 224, 224]
            images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [remaining, 15, 3, 224, 224]
            images = images.view(-1, 3, 224, 224)  # Shape: [remaining*15, 3, 224, 224]
            images = self.preprocess(images)
            # print('images: ', images.shape) #  torch.Size([6*15, 3, 224, 224])

            with torch.no_grad():
                logits, _ = self.clip_model(images, self.text_B)

            tmp = logits.softmax(dim=-1) * 10
            res.append(tmp)

        res = torch.cat(res, dim=0)  # Shape: [length, 5]
        # print('res extract_bc_features: ', res.shape) # torch.Size([150+90, 5])

        # Step 2: Multi-Scale Variance Computation: downsample frames steps
        # smaller step: Captures fast, fine-grained changes.
        # larger step:  Captures slow, long-term trends.
        final_res = []
        for step in [1, 2, 4, 8]:  # Multi-scale temporal steps
            chunk_number = 8 // step
            chunk_size = length // chunk_number
            chunks = []
            for i in range(chunk_number):
                if i < chunk_number - 1:
                    chunk = res[i * chunk_size : (i + 1) * chunk_size, :]
                else:
                    chunk = res[(chunk_number - 1) * chunk_size:, :]  # Handle remaining frames
                tmp = []
                for j in range(step):
                    temp = chunk[j::step, :]  
                    tmp.append(torch.var(temp.float(), dim=0))  # Variance computation
                chunks.append(tmp)  # final chunks len: 8; 4; 2; 1 
            final_res.append(chunks) # final final_res len: 4

        # Step 3: Aggregate Multi-Scale Features
        temp = []
        for i in range(8):  # Aggregate temporal information across 8 time slots
            temp.append(torch.cat(final_res[0][i]                                                # variance for step size = 1
                                + [torch.mean(torch.stack(final_res[1][i // 2], dim=0), dim=0)]  # for step size = 2
                                + [torch.mean(torch.stack(final_res[2][i // 4], dim=0), dim=0)]  # Every 4 slots share the same value.
                                + [torch.mean(torch.stack(final_res[3][i // 8], dim=0), dim=0)]  # for step size = 8
                                , dim=0))

        final_res = torch.stack(temp, dim=0)  # Shape: [8, final_dim]  
        # print('final_res extract_bc_featuresx: ', final_res.shape) # torch.Size([8, 20])

        return final_res

    def extract_temporal_features(self, video_path) -> torch.Tensor:
        """Extracts SlowFast motion features on the entire video segment.

        Args:
            video_path (str): Path to the video file.

        Returns:
            torch.Tensor: Extracted motion features (Shape: [1, feature_dim(2304)]).
        """
        cap = cv2.VideoCapture(video_path)
        video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_indices = np.round(np.linspace(0, video_length - 1, num=8)).astype(int)

        transform = transforms.Compose([
            transforms.Resize([224, 224]),  # Match SlowFast input size
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])  # Original normalization
        ])

        frames = []
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                frames.append(transform(frame))  # Resize & normalize
        cap.release()

        if len(frames) < 8:
            raise ValueError(f"Insufficient frames in {video_path}, expected 8.")

        video_tensor = torch.stack(frames, dim=0)  # Shape: [8, 3, 224, 224]

        # Prepare for SlowFast input
        video_tensor = video_tensor.unsqueeze(0)  # Add batch dimension: [1, 8, 3, 224, 224]
        video_tensor = video_tensor.permute(0, 2, 1, 3, 4)  # Shape: [1, 3, 8, 224, 224]

        # Pack pathways for SlowFast model
        _, pack_pathway_output = lazy_import()
        inputs = pack_pathway_output(video_tensor, device='cpu')
        # print('inputs len: ', len(inputs))
        # print('inputs[0]: ', inputs[0].shape) # torch.Size([1, 3, 2, 224, 224])
        # print('inputs[1]: ', inputs[1].shape) # torch.Size([1, 3, 8, 224, 224])

        # Extract features using SlowFast
        with torch.no_grad():
            slow_feature, fast_feature = self.slowfast_model(inputs)

        # print('slow_feature extract_temporal_features: ', slow_feature.shape) # torch.Size([1, 2048, 1, 1, 1])
        # print('fast_feature extract_temporal_features: ', fast_feature.shape) # torch.Size([1, 256, 1, 1, 1])

        # Concatenate slow and fast features
        features = torch.cat([slow_feature, fast_feature], dim=1).squeeze(-1).squeeze(-1).squeeze(-1)
        # print('features extract_temporal_features: ', features.shape) # torch.Size([1, 2304])

        return features

    def __getitem__(self, index):
        """
        Returns:
            spatial_features (torch.Tensor): Spatial features. Shape: [8, 3, 672, 1120].
            temporal_features (torch.Tensor): SlowFast motion features. Shape: [1, feature_dim(2304)].
            bns_features (torch.Tensor): Brightness & Noise features. Shape: [8, 300].
            bc_features (torch.Tensor): Temporal brightness contrast features. Shape: [8, final_dim].
            video_name (str): Video filename.
        """
        video_name = self.video_names[index]
        video_path = os.path.join(self.video_dir, video_name)

        spatial_features, bns_features = self.extract_bns_features(video_path)
        bc_features = self.extract_bc_features(video_path)
        temporal_features = self.extract_temporal_features(video_path)

        return spatial_features, temporal_features, bns_features, bc_features, video_name
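
An illustrative sketch of fetching one sample; the import path and file paths are placeholders. Note the return order: spatial, temporal, BNS, BC, name.

from aigve.datasets import LightVQAPlusDataset  # assumed import path

dataset = LightVQAPlusDataset(
    video_dir='videos/',
    prompt_dir='annotations/evaluate.json',
)
spatial, temporal, bns, bc, name = dataset[0]
print(name, spatial.shape, temporal.shape, bns.shape, bc.shape)
# Expected shapes: [8, 3, 672, 1120], [1, 2304], [8, 300], [8, final_dim]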

__getitem__(index)

Returns:

Name Type Description
spatial_features Tensor

Spatial features. Shape: [8, 3, 672, 1120].

temporal_features Tensor

SlowFast motion features. Shape: [1, feature_dim(2304)].

bns_features Tensor

Brightness & Noise features. Shape: [8, 300].

bc_features Tensor

Temporal brightness contrast features. Shape: [8, final_dim].

video_name str

Video filename.

Source code in aigve/datasets/lightvqa_plus_dataset.py
def __getitem__(self, index):
    """
    Returns:
        spatial_features (torch.Tensor): Spatial features. Shape: [8, 3, 672, 1120].
        temporal_features (torch.Tensor): SlowFast motion features. Shape: [1, feature_dim(2304)].
        bns_features (torch.Tensor): Brightness & Noise features. Shape: [8, 300].
        bc_features (torch.Tensor): Temporal brightness contrast features. Shape: [8, final_dim].
        video_name (str): Video filename.
    """
    video_name = self.video_names[index]
    video_path = os.path.join(self.video_dir, video_name)

    spatial_features, bns_features = self.extract_bns_features(video_path)
    bc_features = self.extract_bc_features(video_path)
    temporal_features = self.extract_temporal_features(video_path)

    return spatial_features, temporal_features, bns_features, bc_features, video_name

extract_bc_features(video_path)

Extracts Brightness Consistency features using CLIP-based temporal processing.

Returns:

Type Description
Tensor

torch.Tensor: Extracted BC feature (Shape: [8, final_dim]).

Source code in aigve/datasets/lightvqa_plus_dataset.py
def extract_bc_features(self, video_path) -> torch.Tensor:
    """
    Extracts Brightness Consistency features using CLIP-based temporal processing.

    Returns:
        torch.Tensor: Extracted BC feature (Shape: [8, final_dim]).
    """

    cap = cv2.VideoCapture(video_path)
    video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    frames = []
    for _ in range(video_length):
        ret, frame = cap.read()
        if ret:
            frame = cv2.resize(frame, (1120, 672))
            frames.append(frame)
    cap.release()

    if not frames:
        raise ValueError(f"Failed to extract frames from {video_path}")

    res = []
    now = 0
    interval = 10  # Process 10 frames at a time
    length = len(frames)

    # Step 1: Extract CLIP Features at Fixed Intervals
    while now + interval - 1 < length:
        batch = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i + now], cv2.COLOR_BGR2RGB)))
                for i in range(interval)]
        images = torch.stack(batch, dim=0)
        images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [10, 3, 3, 5, 224, 224]
        images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [10, 5, 3, 3, 224, 224]
        images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [10, 15, 3, 224, 224]
        images = images.view(-1, 3, 224, 224)  # Shape: [10*15, 3, 224, 224]
        images = self.preprocess(images)
        # print('images extract_bc_features', images.shape) # torch.Size([150, 3, 224, 224])

        with torch.no_grad():
            logits, _ = self.clip_model(images, self.text_B)

        tmp = logits.softmax(dim=-1) * 10
        res.append(tmp)
        now += interval

    # Handle Remaining Frames
    if length > now:
        batch = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)))
                for i in range(now, length)]
        images = torch.stack(batch, dim=0)
        images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [remaining(6), 3, 3, 5, 224, 224]
        images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [remaining, 5, 3, 3, 224, 224]
        images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [remaining, 15, 3, 224, 224]
        images = images.view(-1, 3, 224, 224)  # Shape: [remaining*15, 3, 224, 224]
        images = self.preprocess(images)
        # print('images: ', images.shape) #  torch.Size([6*15, 3, 224, 224])

        with torch.no_grad():
            logits, _ = self.clip_model(images, self.text_B)

        tmp = logits.softmax(dim=-1) * 10
        res.append(tmp)

    res = torch.cat(res, dim=0)  # Shape: [length, 5]
    # print('res extract_bc_features: ', res.shape) # torch.Size([150+90, 5])

    # Step 2: Multi-Scale Variance Computation: downsample frames steps
    # smaller step: Captures fast, fine-grained changes.
    # larger step:  Captures slow, long-term trends.
    final_res = []
    for step in [1, 2, 4, 8]:  # Multi-scale temporal steps
        chunk_number = 8 // step
        chunk_size = length // chunk_number
        chunks = []
        for i in range(chunk_number):
            if i < chunk_number - 1:
                chunk = res[i * chunk_size : (i + 1) * chunk_size, :]
            else:
                chunk = res[(chunk_number - 1) * chunk_size:, :]  # Handle remaining frames
            tmp = []
            for j in range(step):
                temp = chunk[j::step, :]  
                tmp.append(torch.var(temp.float(), dim=0))  # Variance computation
            chunks.append(tmp)  # final chunks len: 8; 4; 2; 1 
        final_res.append(chunks) # final final_res len: 4

    # Step 3: Aggregate Multi-Scale Features
    temp = []
    for i in range(8):  # Aggregate temporal information across 8 time slots
        temp.append(torch.cat(final_res[0][i]                                                # variance for step size = 1
                            + [torch.mean(torch.stack(final_res[1][i // 2], dim=0), dim=0)]  # for step size = 2
                            + [torch.mean(torch.stack(final_res[2][i // 4], dim=0), dim=0)]  # Every 4 slots share the same value.
                            + [torch.mean(torch.stack(final_res[3][i // 8], dim=0), dim=0)]  # for step size = 8
                            , dim=0))

    final_res = torch.stack(temp, dim=0)  # Shape: [8, final_dim]  
    # print('final_res extract_bc_featuresx: ', final_res.shape) # torch.Size([8, 20])

    return final_res
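
The aggregation above yields one 5-dimensional block per step size (the step-1 block keeps its single variance vector; the step-2, -4, and -8 blocks average their variance vectors), so each of the 8 slots ends up with 4 * 5 = 20 values. A self-contained toy re-derivation of that aggregation, using random per-frame logits purely for illustration:

import torch

# Pretend we already have per-frame 5-dim brightness logits for a 16-frame video.
res = torch.rand(16, 5)
length = res.shape[0]

slot_features = []
for i in range(8):
    parts = []
    for step in [1, 2, 4, 8]:
        chunk_number = 8 // step
        chunk_size = length // chunk_number
        c = i // step  # which chunk this slot falls into
        chunk = res[c * chunk_size:] if c == chunk_number - 1 else res[c * chunk_size:(c + 1) * chunk_size]
        # variance of every stride-`step` subsequence, then mean over the `step` offsets
        variances = torch.stack([torch.var(chunk[j::step].float(), dim=0) for j in range(step)])
        parts.append(variances.mean(dim=0))
    slot_features.append(torch.cat(parts))

bc = torch.stack(slot_features)
print(bc.shape)  # torch.Size([8, 20]): 4 scales x 5 logits per slot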

extract_bns_features(video_path)

Extracts Brightness & Noise Sensitivity (BNS) features using CLIP.

- Local feature extraction (res1) → uses 8 key frames.
- Global feature extraction (res2) → uses all frames.

Parameters:

Name Type Description Default
video_path str

Path to the video file.

required

Returns:

Name Type Description
spatial_features Tensor

Extracted 8 evenly spaced key frames across the entire video duration. Shape [8, 3, 672, 1120] containing 8 key frames.

final_res Tensor

Extracted BNS feature (Shape: [8, 300]).

Source code in aigve/datasets/lightvqa_plus_dataset.py
def extract_bns_features(self, video_path):
    """Extracts Brightness & Noise Sensitivity (BNS) features using CLIP.
    Local Feature Extraction (res1) → Uses 8 key frames
    Global Feature Extraction (res2) → Uses all frames

    Args:
        video_path (str): Path to the video file.

    Returns:
        spatial_features (torch.Tensor): Extracted 8 evenly spaced key frames across the entire video duration.
            Shape [8, 3, 672, 1120] containing 8 key frames.
        final_res (torch.Tensor): Extracted BNS feature (Shape: [8, 300]).
    """
    # Local Feature Extraction Step 1: Extract key frames
    spatial_features = self.extract_key_frames(video_path) # Shape: [8, 3, 672, 1120]

    # Step 2: Apply unfolding transformation (Strictly following GET_S_F)
    images = spatial_features.unfold(2, 224, 224).unfold(3, 224, 224)  # Break into patches. Shape: [8, 3, 3, 5, 224, 224]
    images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [8, 5, 3, 3, 224, 224]
    images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [8, 15, 3, 224, 224]
    images = images.view(-1, 3, 224, 224)  # Shape: [120, 3, 224, 224]
    images = self.preprocess(images)  # Normalize for CLIP
    # print('images: ', images.shape) # torch.Size([120, 3, 224, 224])
    # print(images.device)
    # print(self.text_N.device)

    # Step 3: Pass through CLIP
    with torch.no_grad():
        logits_N, _ = self.clip_model(images, self.text_N)
        logits_B, _ = self.clip_model(images, self.text_B)

    res_N = logits_N.softmax(dim=-1).view(8, -1) * 10
    # print('res_N: ', res_N.shape) # torch.Size([8, 75])
    res_B = logits_B.softmax(dim=-1).view(8, -1) * 10
    # print('res_B: ', res_N.shape) # torch.Size([8, 75])
    res1 = torch.cat((res_N, res_B), dim=1)
    # print('res1: ', res1.shape) # torch.Size([8, 150])

    # Global Feature Extraction (GET_SF Equivalent)
    res2 = self.get_global_sf(video_path)
    # print('res2: ', res2.shape) # res2:  torch.Size([8, 150])

    # Split & Combine Features
    Nl, Bl = torch.split(res1, 75, dim=1)
    Ng, Bg = torch.split(res2, 75, dim=1)
    final_res = torch.cat([Nl, Ng, Bl, Bg], dim=1)
    # print('final_res: ', final_res.shape)

    return spatial_features, final_res  # Shape: [8, 300]
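
res1 and res2 each hold 75 noise and 75 brightness values per key slot; concatenating the local and global blocks as [Nl, Ng, Bl, Bg] gives the final 300-dim feature. A shape-only sketch with dummy tensors:

import torch

res1 = torch.rand(8, 150)  # local: [noise(75) | brightness(75)] per key frame
res2 = torch.rand(8, 150)  # global: same layout, aggregated over all frames

Nl, Bl = torch.split(res1, 75, dim=1)
Ng, Bg = torch.split(res2, 75, dim=1)
bns = torch.cat([Nl, Ng, Bl, Bg], dim=1)
print(bns.shape)  # torch.Size([8, 300])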

extract_key_frames(video_path)

Extracts 8 evenly spaced key frames across the entire video duration.

Parameters:

Name Type Description Default
video_path str

Path to the video file.

required

Returns:

Name Type Description
spatial_features Tensor

Shape [8, 3, 672, 1120] containing 8 key frames.

Source code in aigve/datasets/lightvqa_plus_dataset.py
def extract_key_frames(self, video_path):
    """
    Extracts 8 evenly spaced key frames across the entire video duration.

    Args:
        video_path (str): Path to the video file.

    Returns:
        spatial_features (torch.Tensor): Shape [8, 3, 672, 1120] containing 8 key frames.
    """
    cap = cv2.VideoCapture(video_path)
    video_name = os.path.basename(video_path).split('.')[0]

    video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    if video_length >= 8:
        # Select 8 unique frame indices evenly spaced across the entire video
        frame_indices = np.round(np.linspace(0, video_length - 1, num=8)).astype(int)
    else:
        # Select all available frames and repeat the last one to reach 8
        frame_indices = list(range(video_length)) + [video_length - 1] * (8 - video_length)

    spatial_features = torch.zeros([8, 3, 672, 1120])  # Ensure exactly 8 frames
    transform = transforms.Compose([
        transforms.Resize([672, 1120]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    last_valid_frame = None
    for idx, frame_idx in enumerate(frame_indices):
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        if ret:
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            spatial_features[idx] = transform(frame)
            last_valid_frame = spatial_features[idx]
        elif last_valid_frame is not None:  # If total frames are less than 8, repeat the last valid frame
            spatial_features[idx] = last_valid_frame

    cap.release()
    # print('spatial_features: ', spatial_features.shape) # torch.Size([8, 3, 672, 1120])
    return spatial_features
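
A quick numeric check of the frame-index selection above, assuming a hypothetical 20-frame video: np.linspace spreads 8 indices over the full duration, and shorter videos repeat the last index.

import numpy as np

video_length = 20  # hypothetical frame count
if video_length >= 8:
    frame_indices = np.round(np.linspace(0, video_length - 1, num=8)).astype(int)
else:
    frame_indices = list(range(video_length)) + [video_length - 1] * (8 - video_length)
print(frame_indices)  # [ 0  3  5  8 11 14 16 19]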

extract_temporal_features(video_path)

Extracts SlowFast motion features on the entire video segment.

Parameters:

Name Type Description Default
video_path str

Path to the video file.

required

Returns:

Type Description
Tensor

torch.Tensor: Extracted motion features (Shape: [1, feature_dim(2304)]).

Source code in aigve/datasets/lightvqa_plus_dataset.py
def extract_temporal_features(self, video_path) -> torch.Tensor:
    """Extracts SlowFast motion features on the entire video segment.

    Args:
        video_path (str): Path to the video file.

    Returns:
        torch.Tensor: Extracted motion features (Shape: [1, feature_dim(2304)]).
    """
    cap = cv2.VideoCapture(video_path)
    video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_indices = np.round(np.linspace(0, video_length - 1, num=8)).astype(int)

    transform = transforms.Compose([
        transforms.Resize([224, 224]),  # Match SlowFast input size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])  # Original normalization
    ])

    frames = []
    for idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frames.append(transform(frame))  # Resize & normalize
    cap.release()

    if len(frames) < 8:
        raise ValueError(f"Insufficient frames in {video_path}, expected 8.")

    video_tensor = torch.stack(frames, dim=0)  # Shape: [8, 3, 224, 224]

    # Prepare for SlowFast input
    video_tensor = video_tensor.unsqueeze(0)  # Add batch dimension: [1, 8, 3, 224, 224]
    video_tensor = video_tensor.permute(0, 2, 1, 3, 4)  # Shape: [1, 3, 8, 224, 224]

    # Pack pathways for SlowFast model
    _, pack_pathway_output = lazy_import()
    inputs = pack_pathway_output(video_tensor, device='cpu')
    # print('inputs len: ', len(inputs))
    # print('inputs[0]: ', inputs[0].shape) # torch.Size([1, 3, 2, 224, 224])
    # print('inputs[1]: ', inputs[1].shape) # torch.Size([1, 3, 8, 224, 224])

    # Extract features using SlowFast
    with torch.no_grad():
        slow_feature, fast_feature = self.slowfast_model(inputs)

    # print('slow_feature extract_temporal_features: ', slow_feature.shape) # torch.Size([1, 2048, 1, 1, 1])
    # print('fast_feature extract_temporal_features: ', fast_feature.shape) # torch.Size([1, 256, 1, 1, 1])

    # Concatenate slow and fast features
    features = torch.cat([slow_feature, fast_feature], dim=1).squeeze(-1).squeeze(-1).squeeze(-1)
    # print('features extract_temporal_features: ', features.shape) # torch.Size([1, 2304])

    return features
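
The actual `pack_pathway_output` helper is imported lazily from the Light-VQA-plus submodule. As a hedged illustration only (assuming the common SlowFast convention of a temporal stride alpha = 4), the packing roughly amounts to keeping all frames for the fast pathway and subsampling frames for the slow pathway:

import torch

def pack_pathway_sketch(frames: torch.Tensor, alpha: int = 4):
    """frames: [1, 3, T, H, W]; returns [slow, fast] pathways (illustrative, not the submodule's code)."""
    fast = frames
    t = frames.shape[2]
    idx = torch.linspace(0, t - 1, t // alpha).long()
    slow = torch.index_select(frames, 2, idx)
    return [slow, fast]

video = torch.rand(1, 3, 8, 224, 224)
slow, fast = pack_pathway_sketch(video)
print(slow.shape, fast.shape)  # torch.Size([1, 3, 2, 224, 224]) torch.Size([1, 3, 8, 224, 224])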

get_global_sf(video_path)

Extracts global brightness & noise features across full video.

Parameters:

Name Type Description Default
video_path str

Path to video file.

required

Returns:

Type Description
Tensor

torch.Tensor: Extracted global features (Shape: [8, 150]).

Source code in aigve/datasets/lightvqa_plus_dataset.py
def get_global_sf(self, video_path) -> torch.Tensor:
    """Extracts global brightness & noise features across full video.

    Args:
        video_path (str): Path to video file.

    Returns:
        torch.Tensor: Extracted global features (Shape: [8, 150]).
    """
    cap = cv2.VideoCapture(video_path)
    video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # print('video_length: ', video_length)  # 16

    frames = []
    for _ in range(video_length):
        ret, frame = cap.read()
        if ret:
            frame = cv2.resize(frame, (1120, 672))
            frames.append(frame)
    cap.release()

    if not frames:
        raise ValueError(f"Failed to extract frames from {video_path}")

    res = []
    length = len(frames)
    now = 0
    interval = 10  # Process 10 frames at a time
    while now + interval - 1 < length:
        final = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i + now], cv2.COLOR_BGR2RGB)))
                for i in range(interval)]

        # Step 1: Convert to tensor batch
        images = torch.stack(final, dim=0)  # Shape: [10, 3, 672, 1120]

        # Step 2: Unfold into patches (Strictly following GET_SF)
        images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [10, 3, 3, 5, 224, 224]
        images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [10, 5, 3, 3, 224, 224]
        images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [10, 15, 3, 224, 224]
        images = images.view(-1, 3, 224, 224)  # Shape: [150, 3, 224, 224]
        images = self.preprocess(images)  # Normalize for CLIP
        # print('images get_global_sf: ', images.shape) # torch.Size([10*15, 3, 224, 224])

        # Step 3: Extract features using CLIP
        with torch.no_grad():
            logits_N, _ = self.clip_model(images, self.text_N)
            logits_B, _ = self.clip_model(images, self.text_B)

        tmp_N = logits_N.softmax(dim=-1).view(interval, -1) * 10
        tmp_B = logits_B.softmax(dim=-1).view(interval, -1) * 10
        # print('tmp_N get_global_sf', tmp_N.shape) # torch.Size([10, 75])
        # print('tmp_B get_global_sf', tmp_B.shape) # torch.Size([10, 75])
        res.append(torch.cat([tmp_N, tmp_B], dim=1))
        now += interval

    # Handle remaining frames
    if length > now:
        final = [self.to_tensor(Image.fromarray(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)))
                for i in range(now, length)]

        images = torch.stack(final, dim=0)  # Shape: [remaining(6), 3, 672, 1120]
        images = images.unfold(2, 224, 224).unfold(3, 224, 224)  # Shape: [remaining, 3, 3, 5, 224, 224]
        images = images.permute(0, 3, 2, 1, 4, 5).contiguous()  # Shape: [remaining, 5, 3, 3, 224, 224]
        images = images.reshape(-1, 15, 3, 224, 224)  # Shape: [remaining, 15, 3, 224, 224]
        images = images.view(-1, 3, 224, 224)  # Shape: [remaining*15, 3, 224, 224]
        images = self.preprocess(images)

        with torch.no_grad():
            logits_N, _ = self.clip_model(images, self.text_N) # Shape: [remaining, 5(num_text_prompts)]
            logits_B, _ = self.clip_model(images, self.text_B) # Shape: [remaining, 5]
            # print('logits_N last get_global_sf', logits_N.shape) # torch.Size([6*15, 5])
            # print('logits_B last get_global_sf', logits_B.shape) #torch.Size([6*15, 5])

        tmp_N = logits_N.softmax(dim=-1).view(length - now, -1) * 10 # Shape: [remaining, 75]
        tmp_B = logits_B.softmax(dim=-1).view(length - now, -1) * 10 # Shape: [remaining, 75]
        # print('tmp_N last get_global_sf', tmp_N.shape)  # torch.Size([6, 75])
        # print('tmp_B last get_global_sf', tmp_B.shape)  # torch.Size([6, 75])

        res.append(torch.cat([tmp_N, tmp_B], dim=1))

    res = torch.cat(res, dim=0)  # Shape: [length, 150]
    # print('res: ', res.shape)  # torch.Size([16, 150]) for toy dataset

    # Step 4: Aggregate into 8 time slots
    chunk_size = length // 8
    final_res = [
        torch.mean(res[i * chunk_size: (i + 1) * chunk_size], dim=0) if i < 7 else torch.mean(res[7 * chunk_size:], dim=0)
        for i in range(8)
    ]

    return torch.stack(final_res, dim=0)  # Shape: [8, 150]
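
The unfold/permute/reshape sequence above tiles each 672x1120 frame into 3 x 5 = 15 non-overlapping 224x224 patches. A shape-only sketch with dummy frames:

import torch

images = torch.rand(10, 3, 672, 1120)                     # 10 resized frames
patches = images.unfold(2, 224, 224).unfold(3, 224, 224)  # [10, 3, 3, 5, 224, 224]
patches = patches.permute(0, 3, 2, 1, 4, 5).contiguous()  # [10, 5, 3, 3, 224, 224]
patches = patches.reshape(-1, 15, 3, 224, 224)            # [10, 15, 3, 224, 224]
patches = patches.view(-1, 3, 224, 224)                   # [150, 3, 224, 224]
print(patches.shape)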

SimpleVQADataset

Bases: Dataset

Dataset for SimpleVQA. Each sample returns:

- spatial_features (torch.Tensor): Extracted spatial frames.
- motion_features (torch.Tensor): Extracted motion-based clips.
- video_name (str): Video filename.
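
A minimal usage sketch with placeholder paths, assuming the return order given in the class docstring above (the import path is also an assumption):

from aigve.datasets import SimpleVQADataset  # assumed import path

dataset = SimpleVQADataset(
    video_dir='videos/',
    prompt_dir='annotations/evaluate.json',
)
spatial_features, motion_features, video_name = dataset[0]
print(video_name, spatial_features.shape)  # spatial: [video_length_read, 3, 448, 448]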

Source code in aigve/datasets/simplevqa_dataset.py
@DATASETS.register_module()
class SimpleVQADataset(Dataset):
    """
    Dataset for SimpleVQA.
    Each sample returns:
        - spatial_features (torch.Tensor): Extracted spatial frames.
        - motion_features (torch.Tensor): Extracted motion-based clips.
        - video_name (str): Video filename.
    """

    def __init__(self, video_dir, prompt_dir, min_video_seconds=8):
        super(SimpleVQADataset, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.video_dir = video_dir
        self.prompt_dir = prompt_dir
        self.min_video_seconds = min_video_seconds

        self.prompts, self.video_names = self._read_prompt_videoname()

    def _read_prompt_videoname(self):
        with open(self.prompt_dir, 'r') as reader:
            read_data = json.load(reader)

        prompt_data_list, video_name_list = [], []
        for item in read_data["data_list"]:
            prompt = item['prompt_gt'].strip()
            video_name = item['video_path_pd'].strip()
            prompt_data_list.append(prompt)
            video_name_list.append(video_name)

        return prompt_data_list, video_name_list

    def __len__(self):
        return len(self.prompts)

    def video_processing_spatial(self, video_path):
        """
        Extracts spatial frames with proper resizing and normalization.
            - Key frame extraction: It selects 1 frame per second.
            - Standard input size: It resizes frames to 448 * 448 (after an initial resize to 520).
        Return:
            transformed_video (torch.Tensor): shape [video_length_read, 3, 448, 448].
                `video_length_read` is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds).
            video_name (str)
        """
        video_capture = cv2.VideoCapture(video_path)
        video_name = os.path.basename(video_path)
        video_length = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
        video_frame_rate = int(round(video_capture.get(cv2.CAP_PROP_FPS)))

        # Compute the number of total seconds of the video
        video_length_read = int(video_length/video_frame_rate) # math.ceil()
        # print('video_length_read (s): ', video_length_read)
        transformations = transforms.Compose([
            transforms.Resize(520),
            transforms.CenterCrop(448),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Standard ImageNet normalization
        ])
        transformed_video = torch.zeros([max(video_length_read, self.min_video_seconds), 3, 448, 448])

        video_read_index = 0
        frame_idx = 0
        for i in range(video_length):
            has_frames, frame = video_capture.read()
            if has_frames:
                # Key frames extraction
                if (video_read_index < video_length_read) and (frame_idx % video_frame_rate == 0):
                    read_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    read_frame = transformations(read_frame)
                    transformed_video[video_read_index] = read_frame
                    video_read_index += 1
                frame_idx += 1

        # Pads remaining frames by repeating the last available frame.
        if video_read_index < self.min_video_seconds:
            for i in range(video_read_index, self.min_video_seconds):
                transformed_video[i] = transformed_video[video_read_index - 1]

        video_capture.release()
        return transformed_video, video_name

    def video_processing_motion(self, video_path):
        """
        Extracts motion-based clips suitable for SlowFast.
            - Standard input size: It resizes frames to 224 * 224.
            - Motion-based clips: Processes at least 8 one-second clips, selecting 32 consecutive frames starting at each second.
        Return:
            transformed_video_all (List[torch.Tensor]): Each tensor has shape [video_length_clip(32), 3, 224, 224].
                The list length is the total number of seconds in the video, with a minimum of 8.
            video_name (str)
        """
        video_capture = cv2.VideoCapture(video_path)
        video_name = os.path.basename(video_path)
        video_length = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
        video_frame_rate = int(round(video_capture.get(cv2.CAP_PROP_FPS)))

        transform = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]) # General purpose
        ])
        transformed_frame_all = torch.zeros([video_length, 3, 224, 224])
        video_read_index = 0
        for i in range(video_length): # All frames extraction
            has_frames, frame = video_capture.read()
            if has_frames:
                read_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                read_frame = transform(read_frame)
                transformed_frame_all[video_read_index] = read_frame
                video_read_index += 1

        # Pads remaining frames by repeating the last available frame.
        if video_read_index < video_length: 
            for i in range(video_read_index, video_length):
                transformed_frame_all[i] = transformed_frame_all[video_read_index - 1]

        video_capture.release()

        # Compute the number of total seconds of the video
        video_clip = int(video_length/video_frame_rate)
        # print('video_clip (s): ', video_clip)
        video_length_clip = 32
        transformed_video_all = []

        # Extract motion-based clips: select 32 consecutive frames from each second
        for i in range(video_clip):
            transformed_video = torch.zeros([video_length_clip, 3, 224, 224])
            if (i * video_frame_rate + video_length_clip) <= video_length: # if the clip can be fully extracted, select 32 consecutive frames starting at i*video_frame_rate
                transformed_video = transformed_frame_all[i * video_frame_rate:(i * video_frame_rate + video_length_clip)]
            else: # Copy all rest available frames. Pads remaining frames by repeating the last available frame.
                transformed_video[:(video_length - i * video_frame_rate)] = transformed_frame_all[i * video_frame_rate:]
                for j in range((video_length - i * video_frame_rate), video_length_clip):
                    transformed_video[j] = transformed_video[video_length - i * video_frame_rate - 1]
            transformed_video_all.append(transformed_video)

        if video_clip < self.min_video_seconds:
            for i in range(video_clip, self.min_video_seconds):
                transformed_video_all.append(transformed_video_all[video_clip - 1])

        return transformed_video_all, video_name

    def __getitem__(self, index):
        """
        Returns:
            spatial_features (torch.Tensor): Shape [v_len_second, 3, 448, 448].
                `v_len_second` is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds).
            motion_features (List[torch.Tensor]): List of motion feature tensors.
                Each tensor has shape [32, 3, 224, 224].
                The list length is the total number of seconds in the video (i.e. v_len_second), with a minimum of 8 (i.e. min_video_seconds).
            video_name (str): Video filename
        """
        video_name = self.video_names[index]
        video_path = os.path.join(self.video_dir, video_name)

        spatial_features, video_name = self.video_processing_spatial(video_path)
        motion_features, video_name = self.video_processing_motion(video_path)
        # print('spatial_features: ', spatial_features.shape) # torch.Size([8, 3, 448, 448]) for toy dataset
        # print('motion_features len: ', len(motion_features)) # 8
        # print('motion_features[0]: ', motion_features[0].shape) # torch.Size([32, 3, 224, 224])

        return spatial_features, motion_features, video_name
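
A minimal usage sketch for SimpleVQADataset, assuming an annotation JSON with a "data_list" of entries that carry "prompt_gt" and "video_path_pd" fields and a matching video folder; the file paths below are placeholders, and the direct import simply mirrors the module path shown above rather than a registry-based config.

from aigve.datasets.simplevqa_dataset import SimpleVQADataset

dataset = SimpleVQADataset(
    video_dir='data/toy/evaluate/',          # placeholder video directory
    prompt_dir='data/toy/annotations.json',  # placeholder annotation file
    min_video_seconds=8,
)

spatial_features, motion_features, video_name = dataset[0]
print(spatial_features.shape)                           # e.g. torch.Size([8, 3, 448, 448])
print(len(motion_features), motion_features[0].shape)   # e.g. 8 torch.Size([32, 3, 224, 224])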

__getitem__(index)

Returns:

Name Type Description
spatial_features Tensor

Shape [v_len_second, 3, 448, 448]. v_len_second is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds).

motion_features List[Tensor]

List of motion feature tensors. Each tensor has shape [32, 3, 224, 224]. The list length is the total number of seconds in the video (i.e. v_len_second), with a minimum of 8 (i.e. min_video_seconds).

video_name str

Video filename

Source code in aigve/datasets/simplevqa_dataset.py
def __getitem__(self, index):
    """
    Returns:
        spatial_features (torch.Tensor): Shape [v_len_second, 3, 448, 448].
            `v_len_second` is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds).
        motion_features (List[torch.Tensor]): List of motion feature tensors.
            Each tensor has shape [32, 3, 224, 224].
            The list length is the total number of seconds in the video (i.e. v_len_second), with a minimum of 8 (i.e. min_video_seconds).
        video_name (str): Video filename
    """
    video_name = self.video_names[index]
    video_path = os.path.join(self.video_dir, video_name)

    spatial_features, video_name = self.video_processing_spatial(video_path)
    motion_features, video_name = self.video_processing_motion(video_path)
    # print('spatial_features: ', spatial_features.shape) # torch.Size([8, 3, 448, 448]) for toy dataset
    # print('motion_features len: ', len(motion_features)) # 8
    # print('motion_features[0]: ', motion_features[0].shape) # torch.Size([32, 3, 224, 224])

    return spatial_features, motion_features, video_name

video_processing_motion(video_path)

Extracts motion-based clips suitable for SlowFast. - Standard input size: It resizes frames to 224 * 224. - Motion-based clips: Processes at least 8 one-second clips, selecting 32 consecutive frames starting at each second. Return: transformed_video_all (List[torch.Tensor]): Each tensor has shape [video_length_clip(32), 3, 224, 224]. The list length is the total number of seconds in the video, with a minimum of 8. video_name (str)

Source code in aigve/datasets/simplevqa_dataset.py
def video_processing_motion(self, video_path):
    """
    Extracts motion-based clips suitable for SlowFast.
        - Standard input size: It resizes frames to 224 * 224.
        - Motion-based clips: Processes at least 8 one-second clips, selecting 32 consecutive frames starting at each second.
    Return:
        transformed_video_all (List[torch.Tensor]): Each tensor has shape [video_length_clip(32), 3, 224, 224].
            The list length is the total number of seconds in the video, with a minimum of 8.
        video_name (str)
    """
    video_capture = cv2.VideoCapture(video_path)
    video_name = os.path.basename(video_path)
    video_length = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    video_frame_rate = int(round(video_capture.get(cv2.CAP_PROP_FPS)))

    transform = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]) # General purpose
    ])
    transformed_frame_all = torch.zeros([video_length, 3, 224, 224])
    video_read_index = 0
    for i in range(video_length): # All frames extraction
        has_frames, frame = video_capture.read()
        if has_frames:
            read_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            read_frame = transform(read_frame)
            transformed_frame_all[video_read_index] = read_frame
            video_read_index += 1

    # Pads remaining frames by repeating the last available frame.
    if video_read_index < video_length: 
        for i in range(video_read_index, video_length):
            transformed_frame_all[i] = transformed_frame_all[video_read_index - 1]

    video_capture.release()

    # Compute the number of total seconds of the video
    video_clip = int(video_length/video_frame_rate)
    # print('video_clip (s): ', video_clip)
    video_length_clip = 32
    transformed_video_all = []

    # Extract motion-based clips: select 32 consecutive frames from each second
    for i in range(video_clip):
        transformed_video = torch.zeros([video_length_clip, 3, 224, 224])
        if (i * video_frame_rate + video_length_clip) <= video_length: # if the clip can be fully extracted, select 32 consecutive frames starting at i*video_frame_rate
            transformed_video = transformed_frame_all[i * video_frame_rate:(i * video_frame_rate + video_length_clip)]
        else: # Copy all rest available frames. Pads remaining frames by repeating the last available frame.
            transformed_video[:(video_length - i * video_frame_rate)] = transformed_frame_all[i * video_frame_rate:]
            for j in range((video_length - i * video_frame_rate), video_length_clip):
                transformed_video[j] = transformed_video[video_length - i * video_frame_rate - 1]
        transformed_video_all.append(transformed_video)

    if video_clip < self.min_video_seconds:
        for i in range(video_clip, self.min_video_seconds):
            transformed_video_all.append(transformed_video_all[video_clip - 1])

    return transformed_video_all, video_name
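
For concreteness, the clip-extraction arithmetic above can be restated with illustrative numbers: a 75-frame video at 30 fps yields video_clip = 2 one-second clips of 32 frames each, which are then padded to min_video_seconds (8) by repeating the last clip. The compact sketch below mirrors that logic; it is not library code, and the zero-filled frames stand in for decoded video.

import torch

video_length, video_frame_rate = 75, 30          # illustrative values
video_length_clip, min_video_seconds = 32, 8
frames = torch.zeros(video_length, 3, 224, 224)  # stand-in for the decoded, transformed frames

clips = []
for i in range(video_length // video_frame_rate):
    start = i * video_frame_rate
    clip = torch.zeros(video_length_clip, 3, 224, 224)
    n = min(video_length_clip, video_length - start)
    clip[:n] = frames[start:start + n]
    if n < video_length_clip:
        clip[n:] = clip[n - 1]        # repeat the last available frame
    clips.append(clip)
while len(clips) < min_video_seconds:
    clips.append(clips[-1])           # pad short videos by repeating the last clip

print(len(clips), clips[0].shape)     # 8 torch.Size([32, 3, 224, 224])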

video_processing_spatial(video_path)

Extracts spatial frames with proper resizing and normalization. - Key frame extraction: It selects 1 frame per second. - Standard input size: It resizes frames to 448 * 448 (after an initial resize to 520). Return: transformed_video (torch.Tensor): shape [video_length_read, 3, 448, 448]. video_length_read is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds). video_name (str)

Source code in aigve/datasets/simplevqa_dataset.py
def video_processing_spatial(self, video_path):
    """
    Extracts spatial frames with proper resizing and normalization.
        - Key frame extraction: It selects 1 frame per second.
        - Standard input size: It resizes frames to 448 * 448 (after an initial resize to 520).
    Return:
        transformed_video (torch.Tensor): shape [video_length_read, 3, 448, 448].
            `video_length_read` is the total number of seconds in the video (e.g. 2 for the toy dataset), with a minimum of 8 (i.e. min_video_seconds).
        video_name (str)
    """
    video_capture = cv2.VideoCapture(video_path)
    video_name = os.path.basename(video_path)
    video_length = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    video_frame_rate = int(round(video_capture.get(cv2.CAP_PROP_FPS)))

    # Compute the number of total seconds of the video
    video_length_read = int(video_length/video_frame_rate) # math.ceil()
    # print('video_length_read (s): ', video_length_read)
    transformations = transforms.Compose([
        transforms.Resize(520),
        transforms.CenterCrop(448),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Standard ImageNet normalization
    ])
    transformed_video = torch.zeros([max(video_length_read, self.min_video_seconds), 3, 448, 448])

    video_read_index = 0
    frame_idx = 0
    for i in range(video_length):
        has_frames, frame = video_capture.read()
        if has_frames:
            # Key frames extraction
            if (video_read_index < video_length_read) and (frame_idx % video_frame_rate == 0):
                read_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                read_frame = transformations(read_frame)
                transformed_video[video_read_index] = read_frame
                video_read_index += 1
            frame_idx += 1

    # Pads remaining frames by repeating the last available frame.
    if video_read_index < self.min_video_seconds:
        for i in range(video_read_index, self.min_video_seconds):
            transformed_video[i] = transformed_video[video_read_index - 1]

    video_capture.release()
    return transformed_video, video_name

ToyDataset

Bases: BaseDataset

ToyDataset for testing.

Parameters:

Name Type Description Default
data_root str

Root directory for data.

None
ann_file str

Annotation file path.

''
metainfo dict

Metadata information.

None
data_prefix dict

Prefix paths for different modalities.

None
pipeline List[Union[Callable, dict]]

Data transformation pipeline.

[]
modality dict

Specifies which modalities are used (video, text, image).

dict(use_video=True, use_text=True, use_image=False)
image_frame int

Number of frames for images.

None
Source code in aigve/datasets/toy_dataset.py
@DATASETS.register_module()
class ToyDataset(BaseDataset):
    """ToyDataset for testing.

    Args:
        data_root (str, optional): Root directory for data.
        ann_file (str): Annotation file path.
        metainfo (dict, optional): Metadata information.
        data_prefix (dict): Prefix paths for different modalities.
        pipeline (List[Union[Callable, dict]]): Data transformation pipeline.
        modality (dict): Specifies which modalities are used (video, text, image).
        image_frame (int, optional): Number of frames for images.
    """

    def __init__(self,
                 data_root: Optional[str] = None,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_prefix: dict = None,
                 pipeline: List[Union[Callable, dict]] = [],
                 modality: dict = dict(use_video=True, use_text=True, use_image=False),
                 image_frame: int = None,
                 **kwargs) -> None:
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            metainfo=metainfo,
            data_prefix=data_prefix,
            pipeline=pipeline,
            **kwargs
        )
        self.modality = modality
        self.image_frame = image_frame
        assert self.modality['use_video'] or self.modality['use_text'], (
            'Please specify the `modality` (`use_video` '
            f', `use_text`) for {self.__class__.__name__}')

    def parse_data_info(self, raw_data_info: dict) -> dict:
        """Parse raw data info."""
        info = {}
        info['img_frame'] = None
        if self.modality['use_text']:
            info['prompt_gt'] = osp.join(self.data_prefix.get('video', ''), 
                                         raw_data_info['prompt_gt'])

        if self.modality['use_video'] or self.modality['use_image']:
            info['video_path_pd'] = osp.join(self.data_prefix.get('video', ''), 
                                     raw_data_info['video_path_pd'])
            if self.modality['use_image']:
                info['img_frame'] = self.image_frame

        return info

parse_data_info(raw_data_info)

Parse raw data info.

Source code in aigve/datasets/toy_dataset.py
def parse_data_info(self, raw_data_info: dict) -> dict:
    """Parse raw data info."""
    info = {}
    info['img_frame'] = None
    if self.modality['use_text']:
        info['prompt_gt'] = osp.join(self.data_prefix.get('video', ''), 
                                     raw_data_info['prompt_gt'])

    if self.modality['use_video'] or self.modality['use_image']:
        info['video_path_pd'] = osp.join(self.data_prefix.get('video', ''), 
                                 raw_data_info['video_path_pd'])
        if self.modality['use_image']:
            info['img_frame'] = self.image_frame

    return info
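
Because ToyDataset builds on MMEngine's BaseDataset, its annotation file is a JSON object with "metainfo" and "data_list" keys, and parse_data_info above consumes each data_list entry's "prompt_gt" and "video_path_pd" fields. The sketch below writes such a file; the metainfo contents, prompts, file names, and paths are illustrative assumptions, not shipped data.

import json

toy_annotations = {
    "metainfo": {"length": 2},  # illustrative metainfo
    "data_list": [
        {"prompt_gt": "a dog running on the beach", "video_path_pd": "dog_beach.mp4"},
        {"prompt_gt": "a red car driving in the rain", "video_path_pd": "red_car_rain.mp4"},
    ],
}
with open("toy_ann.json", "w") as f:  # placeholder path
    json.dump(toy_annotations, f, indent=2)

# The dataset could then be built directly (or by its registered name in an AIGVE config):
# dataset = ToyDataset(ann_file="toy_ann.json",
#                      data_prefix=dict(video="data/toy/evaluate/"),
#                      modality=dict(use_video=True, use_text=True, use_image=False))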

VideoPhyDataset

Bases: Dataset

Source code in aigve/datasets/videophy_dataset.py
@DATASETS.register_module()
class VideoPhyDataset(Dataset):
    def __init__(self, data_path, video_root_path, hf_token, tokenizer=None, processor=None, max_length=2048, media_tokens=['<image>', '<|video|>'], hf_checkpoint='videophysics/videocon_physics'):
        """
        Args:
            data_path (str): Path to the annotation JSON file.
            tokenizer (Tokenizer): Tokenizer object
            processor (Processor): Processor object
            max_length (int): Maximum length of the input sequence
            media_tokens (list): List of media tokens
        """
        self.dataset = json.load(open(data_path))
        self.video_root_path = video_root_path

        self.hf_token = hf_token
        self.hf_checkpoint = hf_checkpoint
        self.max_length = max_length
        self.media_tokens = {k: -int(i + 1) for i, k in enumerate(media_tokens)}
        self.media_lengths = {'<image>': 1 + 64, '<|video|>': 1 + 64}
        self.bucket = {}


        # initialize tokenizer
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            self.tokenizer = LlamaTokenizer.from_pretrained(self.hf_checkpoint, token=self.hf_token)

        MplugOwlImageProcessor, MplugOwlProcessor = lazy_import_mplug_owl()
        self.image_processor = MplugOwlImageProcessor.from_pretrained(self.hf_checkpoint)
        # initialize processor
        if processor is not None:
            self.processor = processor
        else:
            self.processor = MplugOwlProcessor(self.image_processor, self.tokenizer)

    def __len__(self) -> int:
        """
        Returns:
            int: Length of the dataset
        """
        return self.dataset['metainfo']['length']

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the dataset
        Returns:
            dict: Dictionary containing the video, text, video path and caption
        """
        data = self.dataset['dataset_list'][idx]
        videopath = os.path.join(self.video_root_path, data['video_path_pd'])
        caption = data['prompt_gt']
        # video_input = self.processor(videos=[videopath], num_frames=16, return_tensors='pt') # video_pixel_values
        video_input = self.processor(videos=[videopath], num_frames=32, return_tensors='pt')  # video_pixel_values
        text_input = self._extract_text_token_from_conversation(caption, self.max_length, idx)
        item = {'video': video_input, 'text': text_input, 'videopath': videopath, 'caption': caption}
        return item

    def _extract_text_token_from_conversation(self, data, max_length, index):
        """
        Extracts the text tokens from the conversation
        Args:
            data (str): Conversation
            max_length (int): Maximum length of the input sequence
            index (int): Index of the dataset
        """
        # output enc_chunk
        enc_chunk = []

        if self.tokenizer.bos_token_id > 0:
            prompt_chunk = [self.tokenizer.bos_token_id]
        else:
            prompt_chunk = []

        # conversation = data["completion"]
        conversation = data

        # For Text only data
        if all([media_token not in conversation for media_token in self.media_tokens.keys()]):
            pattern = '|'.join(map(re.escape, ['AI: ', '\nHuman: ']))
            chunk_strs = re.split(f'({pattern})', conversation)
            prompt_length = -1
            stop_flag = False
            for idx, chunk_str in enumerate(chunk_strs):
                if idx == 0:
                    enc_chunk = prompt_chunk + \
                                self.tokenizer(chunk_str, add_special_tokens=False)[
                                    'input_ids']
                    enc_length = len(enc_chunk)
                    label_chunk = [0] * enc_length
                else:
                    if chunk_strs[idx - 1] == 'AI: ':
                        curr_chunk = self.tokenizer(
                            chunk_str, add_special_tokens=False)['input_ids']
                        if enc_length + len(curr_chunk) >= max_length:
                            curr_chunk = curr_chunk[:max_length - enc_length]
                            stop_flag = True
                        curr_chunk += [self.tokenizer.eos_token_id]
                        enc_length += len(curr_chunk)
                        enc_chunk += curr_chunk
                        label_chunk += [1] * len(curr_chunk)
                    else:
                        curr_chunk = self.tokenizer(
                            chunk_str, add_special_tokens=False)['input_ids']
                        if enc_length + len(curr_chunk) >= max_length + 1:
                            curr_chunk = curr_chunk[:max_length + 1 - enc_length]
                            stop_flag = True
                        enc_length += len(curr_chunk)
                        enc_chunk += curr_chunk
                        label_chunk += [0] * len(curr_chunk)
                    if stop_flag:
                        break

        # For Image-Text Data
        else:
            enc_length = 0
            prompt_length = -2
            pattern = '|'.join(
                map(re.escape, list(self.media_tokens.keys()) + ['AI: ', '\nHuman: ']))
            chunk_strs = re.split(f'({pattern})', conversation)
            chunk_strs = [x for x in chunk_strs if len(x) > 0]
            for idx, chunk_str in enumerate(chunk_strs):
                if enc_length >= max_length + 1:
                    break

                if idx == 0:
                    enc_chunk = prompt_chunk + \
                                self.tokenizer(chunk_str, add_special_tokens=False)[
                                    'input_ids']
                    enc_length = len(enc_chunk)
                    label_chunk = [0] * enc_length
                else:
                    if chunk_str in self.media_tokens:
                        # [CLS] + 256 + [EOS]
                        if enc_length + self.media_lengths[chunk_str] > max_length + 1:
                            break
                        else:
                            enc_chunk += [self.media_tokens[chunk_str]
                                          ] * self.media_lengths[chunk_str]
                            enc_length += self.media_lengths[chunk_str]
                            label_chunk += [0] * self.media_lengths[chunk_str]
                    else:

                        if chunk_strs[idx - 1] == 'AI: ':
                            curr_chunk = self.tokenizer(
                                chunk_str, add_special_tokens=False)['input_ids']
                            if enc_length + len(curr_chunk) >= max_length:
                                curr_chunk = curr_chunk[:max_length - enc_length]
                            curr_chunk += [self.tokenizer.eos_token_id]
                            enc_length += len(curr_chunk)
                            enc_chunk += curr_chunk
                            label_chunk += [1] * len(curr_chunk)
                        else:
                            curr_chunk = self.tokenizer(
                                chunk_str, add_special_tokens=False)['input_ids']
                            if enc_length + len(curr_chunk) >= max_length + 1:
                                curr_chunk = curr_chunk[:max_length +
                                                         1 - enc_length]
                            enc_length += len(curr_chunk)
                            enc_chunk += curr_chunk
                            label_chunk += [0] * len(curr_chunk)

        if enc_length < max_length + 1:
            padding_chunk = [self.tokenizer.pad_token_id] * \
                            (max_length + 1 - enc_length)
            padding_length = len(padding_chunk)
            label_chunk += [0] * (max_length + 1 - enc_length)
            enc_chunk = enc_chunk + padding_chunk
        else:
            padding_length = 0

        assert enc_length + padding_length == max_length + \
               1, (index, prompt_length, enc_length,
                   padding_length, max_length + 1)
        assert len(label_chunk) == max_length + \
               1, (len(label_chunk), max_length + 1)
        non_padding_mask = [1 if i < enc_length -
                                 1 else 0 for i in range(max_length)]

        enc_chunk = torch.tensor(enc_chunk).long()
        non_padding_mask = torch.tensor(non_padding_mask).long()
        prompt_mask = torch.tensor(label_chunk)[1:].long()
        prompt_length = torch.tensor([prompt_length]).long()

        # Create loss mask
        if all([media_token not in conversation for media_token in self.media_tokens.keys()]):
            non_media_mask = torch.ones_like(non_padding_mask).long()
        else:
            tmp_enc_chunk = enc_chunk.clone()
            tmp_enc_chunk[tmp_enc_chunk >= 0] = 1
            tmp_enc_chunk[tmp_enc_chunk < 0] = 0
            non_media_mask = torch.tensor(tmp_enc_chunk).long()
            non_media_mask = non_media_mask[1:].long()
        return {'input_ids': enc_chunk, "prompt_length": prompt_length, 'seq_length': enc_length,
                "non_padding_mask": non_padding_mask, 'non_media_mask': non_media_mask, 'prompt_mask': prompt_mask}

__getitem__(idx)

Parameters:

Name Type Description Default
idx int

Index of the dataset

required

Returns: dict: Dictionary containing the video, text, video path and caption

Source code in aigve/datasets/videophy_dataset.py
def __getitem__(self, idx):
    """
    Args:
        idx (int): Index of the dataset
    Returns:
        dict: Dictionary containing the video, text, video path and caption
    """
    data = self.dataset['dataset_list'][idx]
    videopath = os.path.join(self.video_root_path, data['video_path_pd'])
    caption = data['prompt_gt']
    # video_input = self.processor(videos=[videopath], num_frames=16, return_tensors='pt') # video_pixel_values
    video_input = self.processor(videos=[videopath], num_frames=32, return_tensors='pt')  # video_pixel_values
    text_input = self._extract_text_token_from_conversation(caption, self.max_length, idx)
    item = {'video': video_input, 'text': text_input, 'videopath': videopath, 'caption': caption}
    return item

__init__(data_path, video_root_path, hf_token, tokenizer=None, processor=None, max_length=2048, media_tokens=['<image>', '<|video|>'], hf_checkpoint='videophysics/videocon_physics')

Parameters:

Name Type Description Default
data_path str

Path to the annotation JSON file.

required
tokenizer Tokenizer

Tokenizer object

None
processor Processor

Processor object

None
max_length int

Maximum length of the input sequence

2048
media_tokens list

List of media tokens

['<image>', '<|video|>']
Source code in aigve/datasets/videophy_dataset.py
def __init__(self, data_path, video_root_path, hf_token, tokenizer=None, processor=None, max_length=2048, media_tokens=['<image>', '<|video|>'], hf_checkpoint='videophysics/videocon_physics'):
    """
    Args:
        data_path (str): Path to the annotation JSON file.
        tokenizer (Tokenizer): Tokenizer object
        processor (Processor): Processor object
        max_length (int): Maximum length of the input sequence
        media_tokens (list): List of media tokens
    """
    self.dataset = json.load(open(data_path))
    self.video_root_path = video_root_path

    self.hf_token = hf_token
    self.hf_checkpoint = hf_checkpoint
    self.max_length = max_length
    self.media_tokens = {k: -int(i + 1) for i, k in enumerate(media_tokens)}
    self.media_lengths = {'<image>': 1 + 64, '<|video|>': 1 + 64}
    self.bucket = {}


    # initialize tokenizer
    if tokenizer is not None:
        self.tokenizer = tokenizer
    else:
        self.tokenizer = LlamaTokenizer.from_pretrained(self.hf_checkpoint, token=self.hf_token)

    MplugOwlImageProcessor, MplugOwlProcessor = lazy_import_mplug_owl()
    self.image_processor = MplugOwlImageProcessor.from_pretrained(self.hf_checkpoint)
    # initialize processor
    if processor is not None:
        self.processor = processor
    else:
        self.processor = MplugOwlProcessor(self.image_processor, self.tokenizer)

__len__()

Returns:

Name Type Description
int int

Length of the dataset

Source code in aigve/datasets/videophy_dataset.py
def __len__(self) -> int:
    """
    Returns:
        int: Length of the dataset
    """
    return self.dataset['metainfo']['length']
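
Putting the pieces together, __len__ reads metainfo["length"] and __getitem__ indexes dataset_list entries by "video_path_pd" and "prompt_gt", so the annotation JSON that VideoPhyDataset expects looks roughly like the sketch below. The field values, file names, and token placeholder are illustrative only.

import json

annotations = {
    "metainfo": {"length": 1},
    "dataset_list": [
        {"video_path_pd": "ball_drop.mp4", "prompt_gt": "A ball drops onto a wooden table."}
    ],
}
with open("videophy_ann.json", "w") as f:  # placeholder path
    json.dump(annotations, f, indent=2)

# dataset = VideoPhyDataset(data_path="videophy_ann.json",
#                           video_root_path="videos/",        # placeholder video root
#                           hf_token="<your_hf_token>")        # needed to load the HF checkpoint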

VideoScoreDataset

Bases: BaseDataset

Source code in aigve/datasets/videoscore_dataset.py
@DATASETS.register_module()
class VideoScoreDataset(BaseDataset):
    def __init__(self, ann_file='', metainfo=None, data_root='', data_prefix={'video_path_pd': ''}, filter_cfg=None, indices=None,
                 serialize_data=True, pipeline=[], test_mode=False, lazy_init=False, max_refetch=1000, model_name = None, regression_query_prompt: str = None,
                max_num_frames: int = None):
        """
        Args:
            ann_file (str): annotation file path
            metainfo (dict): meta information about the dataset
            data_root (str): the root path of the data
            data_prefix (dict): the prefix of the data, for example, the prefix of the image path
            filter_cfg (dict): the filter configuration
            indices (list): the indices of the data
            serialize_data (bool): whether to serialize the data
            pipeline (list): the pipeline of the data
            test_mode (bool): whether in test mode
            lazy_init (bool): whether to lazy initialize the dataset
            max_refetch (int): the maximum number of refetching data
            model_name (str): the name of the model
            regression_query_prompt (str): the prompt for the regression query
            max_num_frames (int): the maximum number of frames
        """
        super(VideoScoreDataset, self).__init__(ann_file, metainfo, data_root, data_prefix, filter_cfg, indices, serialize_data, pipeline, test_mode, lazy_init, max_refetch)
        if model_name is None:
            self.model_name = 'TIGER-Lab/VideoScore-v1.1'
        else:
            self.model_name = model_name

        self.processor = AutoProcessor.from_pretrained(self.model_name,torch_dtype=torch.bfloat16)

        if regression_query_prompt is not None:
            self.regression_query_prompt = regression_query_prompt
        else:
            self.regression_query_prompt = '''
                Suppose you are an expert in judging and evaluating the quality of AI-generated videos,
                please watch the following frames of a given video and see the text prompt for generating the video,
                then give scores from 5 different dimensions:
                (1) visual quality: the quality of the video in terms of clearness, resolution, brightness, and color
                (2) temporal consistency, both the consistency of objects or humans and the smoothness of motion or movements
                (3) dynamic degree, the degree of dynamic changes
                (4) text-to-video alignment, the alignment between the text prompt and the video content
                (5) factual consistency, the consistency of the video content with the common-sense and factual knowledge
                for each dimension, output a float number from 1.0 to 4.0,
                the higher the number is, the better the video performs in that sub-score, 
                the lowest 1.0 means Bad, the highest 4.0 means Perfect/Real (the video is like a real video)
                Here is an output example:
                visual quality: 3.2
                temporal consistency: 2.7
                dynamic degree: 4.0
                text-to-video alignment: 2.3
                factual consistency: 1.8
                For this video, the text prompt is "{text_prompt}",
                all the frames of video are as follows:
            '''
        if max_num_frames is not None:
            self.max_num_frames = max_num_frames
        else:
            self.max_num_frames = 48

    def __len__(self) -> int:
        """
        Returns:
            int: the length of the dataset
        """
        return self.metainfo['length']


    def __getitem__(self, idx):
        """
        Args:
            idx (int): the index of the data
        """
        anno_info = self.get_data_info(idx)
        video_path = os.path.join(self.data_root, anno_info['video_path_pd'])

        container = av.open(video_path)

        total_frames = container.streams.video[0].frames
        if total_frames > self.max_num_frames:
            indices = np.arange(0, total_frames, total_frames / self.max_num_frames).astype(int)
        else:
            indices = np.arange(total_frames)

        frames = [Image.fromarray(x) for x in _read_video_pyav(container, indices)]
        eval_prompt = self.regression_query_prompt.format(text_prompt=anno_info['prompt_gt'])
        num_image_token = eval_prompt.count("<image>")
        if num_image_token < len(frames):
            eval_prompt += "<image> " * (len(frames) - num_image_token)

        flatten_images = []
        for x in [frames]:
            if isinstance(x, list):
                flatten_images.extend(x)
            else:
                flatten_images.append(x)
        flatten_images = [Image.open(x) if isinstance(x, str) else x for x in flatten_images]
        inputs = self.processor(text=eval_prompt, images=flatten_images, return_tensors="pt")
        return inputs
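
A minimal usage sketch for VideoScoreDataset; the import follows the module path shown above, the annotation and data paths are placeholders, and the annotation file is assumed to follow the same metainfo/data_list layout as the other BaseDataset subclasses, with "prompt_gt" and "video_path_pd" per entry. The VideoScore processor is downloaded from Hugging Face on first use.

from aigve.datasets.videoscore_dataset import VideoScoreDataset

dataset = VideoScoreDataset(
    ann_file='videoscore_ann.json',   # placeholder annotation file
    data_root='data/toy/evaluate/',   # placeholder video root
    max_num_frames=48,
)
print(len(dataset))    # number of annotated videos, taken from metainfo['length']
inputs = dataset[0]    # processor output for one video (input_ids, pixel_values, ...)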

__getitem__(idx)

Parameters:

Name Type Description Default
idx int

the index of the data

required
Source code in aigve/datasets/videoscore_dataset.py
def __getitem__(self, idx):
    """
    Args:
        idx (int): the index of the data
    """
    anno_info = self.get_data_info(idx)
    video_path = os.path.join(self.data_root, anno_info['video_path_pd'])

    container = av.open(video_path)

    total_frames = container.streams.video[0].frames
    if total_frames > self.max_num_frames:
        indices = np.arange(0, total_frames, total_frames / self.max_num_frames).astype(int)
    else:
        indices = np.arange(total_frames)

    frames = [Image.fromarray(x) for x in _read_video_pyav(container, indices)]
    eval_prompt = self.regression_query_prompt.format(text_prompt=anno_info['prompt_gt'])
    num_image_token = eval_prompt.count("<image>")
    if num_image_token < len(frames):
        eval_prompt += "<image> " * (len(frames) - num_image_token)

    flatten_images = []
    for x in [frames]:
        if isinstance(x, list):
            flatten_images.extend(x)
        else:
            flatten_images.append(x)
    flatten_images = [Image.open(x) if isinstance(x, str) else x for x in flatten_images]
    inputs = self.processor(text=eval_prompt, images=flatten_images, return_tensors="pt")
    return inputs
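
The helper _read_video_pyav used in __getitem__ is not shown on this page. A common implementation of this kind of helper, following the pattern used in PyAV-based examples, is sketched below; the version in aigve/datasets/videoscore_dataset.py may differ in detail.

import av
import numpy as np

def _read_video_pyav(container, indices):
    """Decode the frames at `indices` and return them as a [N, H, W, 3] uint8 array."""
    frames = []
    container.seek(0)
    start_index, end_index = indices[0], indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([f.to_ndarray(format="rgb24") for f in frames])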

__init__(ann_file='', metainfo=None, data_root='', data_prefix={'video_path_pd': ''}, filter_cfg=None, indices=None, serialize_data=True, pipeline=[], test_mode=False, lazy_init=False, max_refetch=1000, model_name=None, regression_query_prompt=None, max_num_frames=None)

Parameters:

Name Type Description Default
ann_file str

annotation file path

''
metainfo dict

meta information about the dataset

None
data_root str

the root path of the data

''
data_prefix dict

the prefix of the data, for example, the prefix of the image path

{'video_path_pd': ''}
filter_cfg dict

the filter configuration

None
indices list

the indices of the data

None
serialize_data bool

whether to serialize the data

True
pipeline list

the pipeline of the data

[]
test_mode bool

whether in test mode

False
lazy_init bool

whether to lazy initialize the dataset

False
max_refetch int

the maximum number of refetching data

1000
model_name str

the name of the model

None
regression_query_prompt str

the prompt for the regression query

None
max_num_frames int

the maximum number of frames

None
Source code in aigve/datasets/videoscore_dataset.py
def __init__(self, ann_file='', metainfo=None, data_root='', data_prefix={'video_path_pd': ''}, filter_cfg=None, indices=None,
             serialize_data=True, pipeline=[], test_mode=False, lazy_init=False, max_refetch=1000, model_name = None, regression_query_prompt: str = None,
            max_num_frames: int = None):
    """
    Args:
        ann_file (str): annotation file path
        metainfo (dict): meta information about the dataset
        data_root (str): the root path of the data
        data_prefix (dict): the prefix of the data, for example, the prefix of the image path
        filter_cfg (dict): the filter configuration
        indices (list): the indices of the data
        serialize_data (bool): whether to serialize the data
        pipeline (list): the pipeline of the data
        test_mode (bool): whether in test mode
        lazy_init (bool): whether to lazy initialize the dataset
        max_refetch (int): the maximum number of refetching data
        model_name (str): the name of the model
        regression_query_prompt (str): the prompt for the regression query
        max_num_frames (int): the maximum number of frames
    """
    super(VideoScoreDataset, self).__init__(ann_file, metainfo, data_root, data_prefix, filter_cfg, indices, serialize_data, pipeline, test_mode, lazy_init, max_refetch)
    if model_name is None:
        self.model_name = 'TIGER-Lab/VideoScore-v1.1'
    else:
        self.model_name = model_name

    self.processor = AutoProcessor.from_pretrained(self.model_name,torch_dtype=torch.bfloat16)

    if regression_query_prompt is not None:
        self.regression_query_prompt = regression_query_prompt
    else:
        self.regression_query_prompt = '''
            Suppose you are an expert in judging and evaluating the quality of AI-generated videos,
            please watch the following frames of a given video and see the text prompt for generating the video,
            then give scores from 5 different dimensions:
            (1) visual quality: the quality of the video in terms of clearness, resolution, brightness, and color
            (2) temporal consistency, both the consistency of objects or humans and the smoothness of motion or movements
            (3) dynamic degree, the degree of dynamic changes
            (4) text-to-video alignment, the alignment between the text prompt and the video content
            (5) factual consistency, the consistency of the video content with the common-sense and factual knowledge
            for each dimension, output a float number from 1.0 to 4.0,
            the higher the number is, the better the video performs in that sub-score, 
            the lowest 1.0 means Bad, the highest 4.0 means Perfect/Real (the video is like a real video)
            Here is an output example:
            visual quality: 3.2
            temporal consistency: 2.7
            dynamic degree: 4.0
            text-to-video alignment: 2.3
            factual consistency: 1.8
            For this video, the text prompt is "{text_prompt}",
            all the frames of video are as follows:
        '''
    if max_num_frames is not None:
        self.max_num_frames = max_num_frames
    else:
        self.max_num_frames = 48

__len__()

Returns:

Name Type Description
int int

the length of the dataset

Source code in aigve/datasets/videoscore_dataset.py
def __len__(self) -> int:
    """
    Returns:
        int: the length of the dataset
    """
    return self.metainfo['length']