Add BEiT3 #2489

Open: brianhou0208 wants to merge 4 commits into main

Conversation

brianhou0208 (Contributor) commented:

BEiT-3 (CVPR 2023) is a multimodal model. Although it does not stand out on ImageNet, it achieves impressive results in other domains, and thanks to its large-scale pretraining it delivers strong performance on downstream tasks.


Model Issue & Request

Results (ImageNet)

https://github.com/microsoft/unilm/tree/master/beit3#fine-tuning-on-imagenet-1k-image-classification

| Model | Weight | Acc@1 | Acc@5 | FLOPs (G) | MACs (G) | Params (M) |
|---|---|---|---|---|---|---|
| beit3_base_patch16_224 | in22k_ft_in1k | 85.370 | 97.640 | 35.13 | 15.57 | 86.66 |
| beit3_base_patch16_224 | in22k_indomain_ft_in1k | 85.446 | 97.616 | | | |
| beit3_large_patch16_224 | in22k_ft_in1k | 87.624 | 98.332 | 123.12 | 61.56 | 304.57 |
| beit3_large_patch16_224 | in22k_indomain_ft_in1k | 87.538 | 98.362 | | | |
| beit3_giant_patch14_224 | | | | 534.09 | 267.05 | 1000.1 |
| beit3_giant_patch14_336 | | | | 1240.70 | 620.35 | 1000.1 |

Note

The performance reported in the paper is based on the Giant model, and the authors do not plan to release its weights (see microsoft/unilm#1031, microsoft/unilm#1382, microsoft/unilm#1435).
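
Once the models are registered in timm, a quick smoke test of the checkpoints could look like the minimal sketch below (model names taken from the table above); the full evaluation script follows.

```python
# Minimal sketch: load one BEiT3 classifier through timm and check the output shape.
# Assumes the beit3_* names from the table above are registered once this PR is merged.
import timm
import torch

model = timm.create_model('beit3_base_patch16_224', pretrained=True).eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))

print(logits.shape)  # expected: torch.Size([1, 1000]) for the ImageNet-1k head
```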


Test code:

```python
from typing import Any, Dict, Union, List
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import timm
from timm.utils.metrics import AverageMeter, accuracy

# run evaluation on Apple Silicon GPU via the MPS backend
device = torch.device('mps')
torch.mps.empty_cache()

def auto_unit(x: float, unit: str = '') -> str:
    """Format a raw count with an auto-scaled G/M/K suffix."""
    if x >= 1e9:
        return f"{x / 1e9:.2f}G {unit}"
    elif x >= 1e6:
        return f"{x / 1e6:.2f}M {unit}"
    elif x >= 1e3:
        return f"{x / 1e3:.2f}K {unit}"
    else:
        return f"{x:.2f} {unit}"
 
 
def get_model_info(model: torch.nn.Module, imgsz: Union[int, List[int]] = 224) -> Dict[str, str]:
    """
    Compute model FLOPs, MACs, and Params using torch profiler.

    Args:
        model (nn.Module): The model to calculate for.
        imgsz (int | List[int], optional): Input image size. Defaults to 224.

    Returns:
        dict: Dictionary containing FLOPs, MACs, and Params with auto units.
    """
    p = next(model.parameters())
    if not isinstance(imgsz, list):
        imgsz = [imgsz, imgsz]

    im = torch.empty((1, 3, *imgsz), device=p.device)

    with torch.profiler.profile(with_flops=True) as prof:
        model(im)

    flops = sum(e.flops for e in prof.key_averages())
    macs = flops / 2  # the profiler counts one multiply-accumulate as two FLOPs
    params = sum(p.numel() for p in model.parameters())

    return {
        "FLOPs": auto_unit(flops, ""),
        "MACs": auto_unit(macs, ""),
        "Params": auto_unit(params, ""),
    }


def get_model_acc(model: torch.nn.Module):
    """Evaluate top-1 / top-5 ImageNet accuracy using the model's default_cfg preprocessing."""
    cfg: Dict[str, Any] = model.default_cfg
    _, height, width = cfg['input_size'] if 'test_input_size' not in cfg else cfg['test_input_size']
    imgsz = height if height == width else (height, width)

    # map timm interpolation names to torchvision/PIL resampling codes
    interp_mode = {
        "nearest": 0,
        "bilinear": 2,
        "bicubic": 3,
    }

    val_dataset = datasets.ImageFolder(
        '/Users/ryanhou/Downloads/imagenet/val',  # local path to the ImageNet-1k validation split
        transforms.Compose([
            transforms.Resize(int(imgsz / cfg['crop_pct']), interpolation=interp_mode[cfg['interpolation']]),
            transforms.CenterCrop(imgsz),
            transforms.ToTensor(),
            transforms.Normalize(cfg['mean'], cfg['std'])])
    )
    val_loader = DataLoader(
        val_dataset, batch_size=64, shuffle=False, pin_memory=False, prefetch_factor=4, num_workers=4,
        persistent_workers=True#, pin_memory_device='mps'
    )

    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    model.to(device)
    torch.mps.synchronize()
    with torch.no_grad():
        for images, target in tqdm(val_loader):
            images = images.to(device)
            target = target.to(device)
            output = model(images)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1, images.size(0))
            top5.update(acc5, images.size(0))
    torch.mps.synchronize()
    return {"ACC@1": round(top1.avg.item(), 4), "ACC@5": round(top5.avg.item(), 4)}
 
 
if __name__ == "__main__":

    for name in timm.list_models('beit3*', pretrained=True):
        model = timm.create_model(name, pretrained=True).eval()
        result = get_model_acc(model)
        print(name, result)
```
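
Since `get_model_info` is defined above but never called in `__main__`, here is a short usage sketch for reproducing the FLOPs / MACs / Params columns of the table (assuming the function above is in scope and the `beit3_*` models from this PR are registered):

```python
# Sketch: report FLOPs / MACs / Params for each BEiT3 variant at its default input size.
# Assumes get_model_info from the script above is in scope.
import timm

for name in timm.list_models('beit3*'):
    model = timm.create_model(name, pretrained=False).eval()
    imgsz = model.default_cfg['input_size'][-1]  # 224 or 336 depending on the variant
    print(name, get_model_info(model, imgsz=imgsz))
```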

Reference

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
