Skip to content

Commit

Permalink
better but still 0.2 lower than paper
Browse files Browse the repository at this point in the history
  • Loading branch information
CoinCheung committed Nov 22, 2022
1 parent 78abd1b commit b5a5e24
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 81 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ main_mfm.py
main_finetune.py
mfm/__pycache__/
.nfs*
*tar
try.py
56 changes: 31 additions & 25 deletions dist_train.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@

export OMP_NUM_THREADS=4
export MKL_NUM_THREADS=4

export CUDA_VISIBLE_DEVICES=0,1,2,3
NGPUS=4
LR=0.0006
# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# NGPUS=8
# LR=0.0012

# rm ./res_pretrain/model_300.pth

Expand All @@ -16,7 +20,7 @@ torchrun --nproc_per_node=$NGPUS train_mfm.py \
--epochs 300 \
--opt adamw \
--batch-size 256 \
--lr 0.0006 \
--lr $LR \
--wd 0.05 \
--lr-scheduler cosineannealinglr \
--lr-warmup-epochs 20 \
Expand All @@ -32,6 +36,8 @@ torchrun --nproc_per_node=$NGPUS train_mfm.py \
# --resume res_pretrain/model_80.pth \


sleep 10

# finetune 100ep
export CUDA_VISIBLE_DEVICES=0,1,2,3
NGPUS=4
Expand All @@ -42,27 +48,27 @@ LR=0.006

## bs=2048, lr=1.2e-2

# torchrun --nproc_per_node=$NGPUS train_finetune.py \
# --data-path ./imagenet/ \
# --model resnet50 \
# --batch-size 256 \
# --epochs 100 \
# --opt adamw \
# --lr $LR \
# --wd 0.02 \
# --label-smoothing 0.1 \
# --mixup-alpha 0.1 \
# --cutmix-alpha 1.0 \
# --lr-scheduler cosineannealinglr \
# --lr-warmup-epochs 5 \
# --lr-warmup-method linear \
# --output-dir ./res_finetune \
# --auto-augment ra_6_10 \
# --weights ./res_pretrain/model_300.pth \
# --amp \
# --val-resize-size 236 \
# --train-crop-size 160
#
#
# # --lr 0.009 \
# # --lr 0.012 \
torchrun --nproc_per_node=$NGPUS train_finetune.py \
--data-path ./imagenet/ \
--model resnet50 \
--batch-size 256 \
--epochs 100 \
--opt adamw \
--lr $LR \
--wd 0.02 \
--label-smoothing 0.1 \
--mixup-alpha 0.1 \
--cutmix-alpha 1.0 \
--lr-scheduler cosineannealinglr \
--lr-warmup-epochs 5 \
--lr-warmup-method linear \
--output-dir ./res_finetune \
--auto-augment ra_6_10 \
--weights ./res_pretrain/model_300.pth \
--amp \
--val-resize-size 236 \
--train-crop-size 160


# --lr 0.009 \
# --lr 0.012 \
4 changes: 2 additions & 2 deletions mfm/dali_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def create_dali_pipeline(data_dir, crop, size, shard_id, num_shards, dali_cpu=Fa
interp_type=types.INTERP_TRIANGULAR)
mirror = False

mean = [0.485, 0.456 ,0.406]
std = [0.229, 0.224 ,0.225]
mean = [0.4, 0.4 ,0.4]
std = [0.2, 0.2 ,0.2]
images = fn.crop_mirror_normalize(images.gpu(),
dtype=types.FLOAT,
output_layout="CHW",
Expand Down
15 changes: 12 additions & 3 deletions mfm/fft_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,19 @@ def __call__(self, x):
'''
input is pytorch tensor of nchw
'''
x_fft = torch.fft.fft2(x)
x_fft = torch.fft.fft2(x, norm='ortho')
x_fft_shift = torch.fft.fftshift(x_fft)
fft_mask = self.gen_mask_circle(x_fft_shift)
x_fft_filter = fft_mask * x_fft_shift
x_fft_ishift = torch.fft.ifftshift(x_fft_filter)
x_ifft = torch.fft.ifft2(x_fft_ishift).abs()
return x_ifft, x_fft_shift, fft_mask
x_ifft = torch.fft.ifft2(x_fft_ishift, norm='ortho').abs()
# return x_ifft, x_fft_shift, fft_mask

## use a one-channel (grayscale) spectrum as the target, replacing x_fft_shift
gray = x[:, 0:1, ...] * 0.229 + x[:, 1:2, ...] * 0.587 + x[:, 2:3, ...] * 0.114
gray_fft = torch.fft.fft2(gray, norm='ortho')
gray_fft_shift = torch.fft.fftshift(gray_fft)
return x_ifft, gray_fft_shift, fft_mask


@torch.no_grad()
Expand Down Expand Up @@ -51,10 +57,13 @@ def gen_mask_circle(self, x):
'''
device = x.device
batchsize = x.size(0)
n_chan = x.size(1)
size = tuple(x.size()[-2:])
if self.mask_table is None or size != self.mask_size:
self.gen_mask_table(x)
inds = torch.randint(0, 2, (batchsize, ), device=device)
batch_mask = self.mask_table[inds] # n1hw
# inds = torch.randint(0, 2, (batchsize, n_chan), device=device)
# batch_mask = self.mask_table.squeeze(1)[inds] # nchw
return batch_mask

166 changes: 126 additions & 40 deletions mfm/focal_frequency_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, gamma=1.):
# # loss = diff.abs().pow(self.gamma)
# # loss = (diff.real.pow(2.) + diff.imag.pow(2.)).pow(self.gamma/2)
# if not mask is None:
# mask = mask.expand_as(preds)
# mask = mask.expand_as(preds) ## don't forget expand_as here, because the results of the three channels are also averaged (mean)
# loss = loss * (1. - mask)
# loss = loss.sum(dim=(1,2,3))
# n_pixel = (1. - mask).sum(dim=(1,2,3))
Expand Down Expand Up @@ -56,54 +56,140 @@ def __init__(self, gamma=1.):
#
# return loss

## this is the version that was just run
# def forward(self, preds, target, mask=None):
# '''
# preds is nchw real tensor
# target is nchw complex tensor
# mask is n1hw real tensor, we should use (1-mask) if we reconstruct masked portion
# '''
# preds = preds.float()
# p_fft = torch.fft.fft2(preds)
# p_fft_shift = torch.fft.fftshift(p_fft)
#
# target = target.detach()
# # replace = (target + p_fft_shift) / 2.
# # p_fft_shift = replace * mask + p_fft_shift * (1. - mask)
#
# mask = (1 - mask).expand_as(preds).bool()
# p_fft_shift = p_fft_shift[mask]
# target = target[mask]
#
# # avg spectrum
# # p_fft_shift = p_fft_shift.mean(dim=0, keepdim=True)
# # target = target.mean(dim=0, keepdim=True)
#
# ## pow(2) first, then split and sum
# # diff = (target - p_fft_shift).pow(2.)
# # diff = diff.real + diff.imag
#
# ## split first, then pow(2) and sum
# diff = target - p_fft_shift
# diff = torch.stack([diff.real, diff.imag], dim=-1).pow(2.)
# diff = diff[..., 0] + diff[..., 1]
#
# # diff = diff[(1 - mask.expand_as(diff)).bool()]
#
# with torch.no_grad():
# weight = diff.pow(self.gamma)
# # adjust spectrum by log
# # weight = torch.log(weight + 1.)
# # use batch-based statistics to compute spectrum weight
# batch_matrix = True
# if batch_matrix:
# weight = weight / weight.max()
# else:
# n, c = weight.size()[:2]
# weight = weight / weight.flatten(1).max(dim=-1, keepdim=True).values
# weight = weight.reshape(n, c, 1, 1)
# # fix bad values
# weight[weight.isnan()] = 0.
# weight = weight.clamp(min=0., max=1.)
#
# loss = diff * weight
# return loss.mean()


# def forward(self, preds, target, mask=None):
# '''
# preds is nchw real tensor
# target is nchw complex tensor
# mask is n1hw real tensor, we should use (1-mask) if we reconstruct masked portion
# '''
# preds = preds.float()
# p_fft = torch.fft.fft2(preds)
# p_fft_shift = torch.fft.fftshift(p_fft)
#
# target = target.detach()
# # replace = (target + p_fft_shift) / 2.
# # p_fft_shift = replace * mask + p_fft_shift * (1. - mask)
#
# # mask = (1 - mask).expand_as(preds).bool()
# mask = mask.expand_as(preds)
# # p_fft_shift = p_fft_shift[mask]
# # target = target[mask]
#
# # avg spectrum
# # p_fft_shift = p_fft_shift.mean(dim=0, keepdim=True)
# # target = target.mean(dim=0, keepdim=True)
#
# ## pow(2) first, then split and sum
# # diff = (target - p_fft_shift).pow(2.)
# # diff = diff.real + diff.imag
#
# ## split first, then pow(2) and sum
# diff = target - p_fft_shift
# diff = torch.stack([diff.real, diff.imag], dim=-1).pow(2.)
# diff = diff[..., 0] + diff[..., 1]
#
# # diff = diff[(1 - mask.expand_as(diff)).bool()]
#
# with torch.no_grad():
# weight = diff.pow(self.gamma)
# # adjust spectrum by log
# # weight = torch.log(weight + 1.)
# # use batch-based statistics to compute spectrum weight
# batch_matrix = True
# if batch_matrix:
# weight = weight / weight.max()
# else:
# n, c = weight.size()[:2]
# weight[mask.bool()] = 0.
# weight = weight / weight.flatten(1).max(dim=-1, keepdim=True).values
# weight = weight.reshape(n, c, 1, 1)
# # fix bad values
# weight[weight.isnan()] = 0.
# weight = weight.clamp(min=0., max=1.)
#
# n_pixel = preds.numel() - mask.sum()
#
# loss = diff * weight
# loss = loss.sum() / n_pixel
# return loss


def forward(self, preds, target, mask=None):
'''
preds is nchw real tensor
target is nchw complex tensor
mask is n1hw real tensor, we should use (1-mask) if we reconstruct masked portion
'''
preds = preds.float()
p_fft = torch.fft.fft2(preds)
p_fft = torch.fft.fft2(preds, norm='ortho')
p_fft_shift = torch.fft.fftshift(p_fft)

target = target.detach()
# replace = (target + p_fft_shift) / 2.
# p_fft_shift = replace * mask + p_fft_shift * (1. - mask)

mask = (1 - mask).expand_as(preds).bool()
p_fft_shift = p_fft_shift[mask]
target = target[mask]

# avg spectrum
# p_fft_shift = p_fft_shift.mean(dim=0, keepdim=True)
# target = target.mean(dim=0, keepdim=True)

## pow(2) first, then split and sum
# diff = (target - p_fft_shift).pow(2.)
# diff = diff.real + diff.imag

## split first, then pow(2) and sum
diff = target - p_fft_shift
diff = torch.stack([diff.real, diff.imag], dim=-1).pow(2.)
diff = diff[..., 0] + diff[..., 1]

# diff = diff[(1 - mask.expand_as(diff)).bool()]
# loss = diff.abs()
eps = 1e-7
loss = (diff.real.pow(2.) + diff.imag.pow(2.) + eps).pow(self.gamma/2.)
if not mask is None:
mask = mask.expand_as(preds) ## don't forget expand_as here, because the results of the three channels are also averaged (mean)
mask = 1. - mask
loss = loss * mask
loss = loss.sum(dim=(2,3))
n_pixel = mask.sum(dim=(2,3))
loss = loss / n_pixel
loss = loss.mean()
return loss

with torch.no_grad():
weight = diff.pow(self.gamma)
# adjust spectrum by log
weight = torch.log(weight + 1.)
# use batch-based statistics to compute spectrum weight
batch_matrix = True
if batch_matrix:
weight = weight / weight.max()
else:
n, c = weight.size()[:2]
weight = weight / weight.flatten(1).max(dim=-1, keepdim=True).values
weight = weight.reshape(n, c, 1, 1)
# fix bad values
weight[weight.isnan()] = 0.
weight = weight.clamp(min=0., max=1.)

loss = diff * weight
return loss.mean()
12 changes: 6 additions & 6 deletions mfm/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ def __init__(
self,
*,
crop_size,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
mean=(0.4, 0.4, 0.4),
std=(0.2, 0.2, 0.2),
interpolation=InterpolationMode.BILINEAR,
hflip_prob=0.5,
auto_augment_policy=None,
Expand Down Expand Up @@ -57,8 +57,8 @@ def __init__(
*,
crop_size,
resize_size=256,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
mean=(0.4, 0.4, 0.4),
std=(0.2, 0.2, 0.2),
interpolation=InterpolationMode.BILINEAR,
):

Expand All @@ -82,8 +82,8 @@ def __init__(
self,
*,
crop_size,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
mean=(0.4, 0.4, 0.4),
std=(0.2, 0.2, 0.2),
interpolation=InterpolationMode.BILINEAR,
hflip_prob=0.5,
):
Expand Down
8 changes: 5 additions & 3 deletions mfm/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,16 @@ def _make_layer(
def mfm(self):
self.avgpool = nn.Identity()
self.fc = nn.Identity()
out_chan = self.layer4[-1].expansion * 512
bb_out_chan = self.layer4[-1].expansion * 512
md_out_chan = 1

# self.decoder = nn.Sequential(
# nn.Conv2d(out_chan, 3, 1, 1, 0, bias=True),
# nn.Conv2d(bb_out_chan, md_out_chan, 1, 1, 0, bias=True),
# nn.Upsample(scale_factor=32., mode='bicubic',
# align_corners=False, antialias=True))
# self.decoder = nn.Conv2d(out_chan, 3, 1, 1, 0, bias=True)
self.decoder = nn.Sequential(
nn.Conv2d(out_chan, 3 * 32 * 32, 1, 1, 0, bias=True),
nn.Conv2d(bb_out_chan, md_out_chan * 32 * 32, 1, 1, 0, bias=True),
nn.PixelShuffle(32))
self.forward = self.forward_mfm

Expand Down
Loading

0 comments on commit b5a5e24

Please sign in to comment.