
Compatibility with timm augmentation? #195

@ardasahiner


Hi,

I was attempting to use FFCV with timm, relying on the fact that `torch.nn.Module`s should be compatible with the `pipelines` argument of FFCV's `Loader`. However, I am getting some strange errors and would like some clarification on what is going wrong.

Please see my minimal reproducible example below. I use CIFAR-100 images and timm's `create_transform` function. Because not every transform it returns is an instance of `nn.Module`, I wrapped the ones that aren't in a simple module (`CustomClass`). However, this fails with the error shown below.

Do you have any suggestions about what is causing this, or ideas for a simpler integration with timm? Any help is appreciated.
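There is probably a further mismatch beyond the error itself. As far as I understand FFCV, each pipeline stage receives the whole batch at once. timm's `create_transform`, by contrast, returns per-image transforms, and several of them (for example the RandAugment step) expect a single PIL image. Wrapping such a transform in an `nn.Module` satisfies the type check but does not change the input it expects. The NumPy sketch below shows the general shape of a per-sample adapter. `apply_per_sample` and `hflip` are hypothetical names for illustration, not FFCV or timm APIs, and the (N, H, W, C) uint8 batch layout is an assumption.

```python
from typing import Callable

import numpy as np


def apply_per_sample(batch: np.ndarray,
                     fn: Callable[[np.ndarray], np.ndarray]) -> np.ndarray:
    """Apply a single-image transform `fn` to every image in a batch.

    Hypothetical helper: assumes `batch` has shape (N, H, W, C) and that
    `fn` maps one (H, W, C) image to an array whose shape is the same for
    every image, so the results can be stacked back into one batch.
    """
    return np.stack([fn(img) for img in batch], axis=0)


def hflip(img: np.ndarray) -> np.ndarray:
    # Stand-in per-image transform: mirror one image along its width axis.
    return img[:, ::-1, :]


# Two tiny 2x3 RGB images with distinct pixel values.
batch = np.arange(2 * 2 * 3 * 3, dtype=np.uint8).reshape(2, 2, 3, 3)
out = apply_per_sample(batch, hflip)
print(out.shape)  # (2, 2, 3, 3)
```

For timm's PIL-based transforms, `fn` would also need to convert each image (for example with `PIL.Image.fromarray`) and turn the result back into an array of fixed shape. Because the loop runs in plain Python, such a stage could not sit in a JIT-compiled part of the pipeline and would likely be much slower than FFCV's native transforms.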

Error:

```
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'self': Cannot determine Numba type of <class 'ffcv.transforms.module.ModuleWrapper'>

File "../anaconda3/envs/ffcv/lib/python3.9/site-packages/ffcv/transforms/module.py", line 25:
        def apply_module(inp, _):
            res = self.module(inp)
            ^

During: resolving callee type: type(CPUDispatcher(<function ModuleWrapper.generate_code.<locals>.apply_module at 0x7fef91639670>))
During: typing of call at  (2)

During: resolving callee type: type(CPUDispatcher(<function ModuleWrapper.generate_code.<locals>.apply_module at 0x7fef91639670>))
During: typing of call at  (2)
```
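The traceback shows `ModuleWrapper.generate_code` returning a nested `apply_module` that calls `self.module(inp)`. The stand-in below reproduces that shape in plain Python to show what Numba is complaining about. It uses a hypothetical `ToyModuleWrapper`, not FFCV's actual class. The generated function captures `self` from its enclosing scope, and Numba's nopython mode has to assign a type to every such captured name. An arbitrary Python object such as a `ModuleWrapper` instance has no Numba type, which produces `Cannot determine Numba type`.

```python
from typing import Any, Callable


class ToyModuleWrapper:
    """Hypothetical stand-in mirroring the generate_code shape in the traceback."""

    def __init__(self, module: Callable[[Any], Any]):
        self.module = module

    def generate_code(self) -> Callable[[Any, Any], Any]:
        # Same shape as in the traceback: `self` is captured from the
        # enclosing scope instead of being passed in as an argument.
        def apply_module(inp, _):
            res = self.module(inp)
            return res

        return apply_module


wrapper = ToyModuleWrapper(lambda x: x * 2)
fn = wrapper.generate_code()

# Plain Python runs the closure without complaint...
print(fn(3, None))              # 6
# ...but a compiler must type every captured name, including `self`.
print(fn.__code__.co_freevars)  # ('self',)
```

The failure only appears when the function is compiled. That suggests the wrapped module ended up in a JIT-compiled part of the pipeline, directly after `SimpleRGBImageDecoder()`. As far as I can tell, FFCV's own CIFAR example places `torchvision.transforms.Normalize` after `ToTensor()` and `ToTorchImage()`. Assuming `ToTensor()` switches the pipeline out of JIT mode, that ordering avoids compiling the module.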

Implementation:

```python
import torch
import numpy as np
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform

from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.fields.rgb_image import CenterCropRGBImageDecoder, \
    RandomResizedCropRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, NormalizeImage, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze


class CustomClass(torch.nn.Module):
    """Wraps a timm/torchvision transform that is not itself an nn.Module."""

    def __init__(self, transform):
        super().__init__()
        self.transform = transform

    def get_params(self, img, scale, ratio):
        # Needs `self` to reach the wrapped transform, so it cannot be a @staticmethod.
        return self.transform.get_params(img, scale, ratio)

    def forward(self, img):
        return self.transform(img)

    def __repr__(self):
        return repr(self.transform)


is_train = True
imnet_mean, imnet_std = np.array(IMAGENET_DEFAULT_MEAN)*256, np.array(IMAGENET_DEFAULT_STD)*256

paths = {
    'train': 'cifar100_train.beton',
    'test': 'cifar100_test.beton'
}

to_module = create_transform(
    input_size=224,
    is_training=True,
    color_jitter=0.4,
    auto_augment='rand-m9-mstd0.5-inc1',
    interpolation='bicubic',
    re_prob=0.25,
    re_mode='pixel',
    re_count=False,
    mean=imnet_mean,
    std=imnet_std,
)

# Keep transforms that are already nn.Modules; wrap the rest in CustomClass.
module_list = []
for t in to_module.transforms:
    if isinstance(t, torch.nn.Module):
        module_list.append(t)
    else:
        module_list.append(CustomClass(t))

transform = torch.nn.Sequential(*module_list)

label_pipeline = [IntDecoder(), ToTensor(), Squeeze()]
image_pipeline = [SimpleRGBImageDecoder(), transform]

ordering = OrderOption.QUASI_RANDOM if is_train else OrderOption.SEQUENTIAL
dataset = Loader(paths['train'] if is_train else paths['test'],
                 batch_size=10, num_workers=2, order=ordering,
                 drop_last=is_train, os_cache=True, distributed=False,
                 pipelines={'image': image_pipeline, 'label': label_pipeline})

for i, (image, label) in enumerate(dataset):
    if i == 1:
        break
    print('loaded one batch')
```
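A side note on the line that multiplies `IMAGENET_DEFAULT_MEAN` and `IMAGENET_DEFAULT_STD` by 256. timm's `create_transform` ends its training pipeline with `ToTensor()` followed by `Normalize()`. Assuming that ordering, it normalizes images already scaled to [0, 1] and expects the unscaled ImageNet statistics. Statistics scaled by 255, as in FFCV's `NormalizeImage` examples, are only appropriate when normalizing raw [0, 255] pixel values, and 256 matches neither case. The short NumPy check below compares the variants on one illustrative pixel.

```python
import numpy as np

# ImageNet statistics (the values of timm's IMAGENET_DEFAULT_MEAN / _STD).
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

pixel = np.array([200.0, 100.0, 50.0])  # illustrative RGB pixel in [0, 255]

# Unscaled stats on a [0, 1] image (what ToTensor + Normalize computes).
ref = (pixel / 255.0 - mean) / std
# Stats scaled by 255 on raw [0, 255] pixels: algebraically identical.
raw_255 = (pixel - mean * 255.0) / (std * 255.0)
# Stats scaled by 256 on raw pixels: slightly off.
raw_256 = (pixel - mean * 256.0) / (std * 256.0)
# Stats scaled by 256 on a [0, 1] image: every channel lands near -2.
mixed = (pixel / 255.0 - mean * 256.0) / (std * 256.0)

print(np.allclose(ref, raw_255))  # True
print(np.allclose(ref, raw_256))  # False
```

The last case corresponds to what the example above passes to timm. For any input in [0, 1], each channel ends up between roughly -1.8 and -2.1, so almost all image contrast is lost.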
