Environment Preparation:
https://github.com/vturrisi/solo-learn (see its documentation for details)

pip install -r requirements

Datasets:
Except for ImageNet-1k, which must be downloaded manually from the official website, all datasets are downloaded automatically during training. You only need to specify the config file.

Training:
base in ['byol', 'nnbyol', 'mocov2plus', 'mocov3', 'dino']
method in ['none', 'risp', 'so']
regular_weight in ['1e-4', '1e-6'] (refer to our recipe for adding orthogonality regularization, OR)
CIFAR:
cmd = 'python3 -u main_pretrain.py --config-path scripts/pretrain/cifar --config-name {}.yaml ++regular_method={} ++regular_weight={}'.format(base, method, regular_weight)
ImageNet:
cmd = 'python3 -u main_pretrain.py --config-path scripts/pretrain/imagenet --config-name {}.yaml ++regular_method={} ++regular_weight={}'.format(base, method, regular_weight)
Linear classifiers are also trained separately during pre-training.


Transfer Learning:
python3 -u main_linear.py  --config-path scripts/linear/cifar100 --config-name byol.yaml
Specify the pre-trained weights and the target dataset in the config file.

Object detection:
Refer to the README in downstream/object_detection.



def _gram_residual(W, device):
    """Return ``G - I`` for the smaller Gram matrix ``G`` of ``W`` viewed as 2-D.

    ``W`` is reshaped to ``(rows, cols)`` with ``rows = W.shape[0]`` and
    ``cols = W[0].numel()``, i.e. one row per index of the first dimension.
    ``G`` is ``W^T W`` (``cols x cols``) when ``rows > cols`` and ``W W^T``
    (``rows x rows``) otherwise, so it is always the smaller Gram matrix.
    """
    cols = W[0].numel()
    rows = W.shape[0]
    w2d = W.reshape(-1, cols)
    if rows > cols:
        gram = torch.matmul(w2d.t(), w2d)
    else:
        gram = torch.matmul(w2d, w2d.t())
    # The identity is a constant target; it needs no gradient.
    ident = torch.eye(gram.shape[0], device=device)
    return gram - ident


def _spectral_penalty(residual, device, eps=1e-12):
    """Return a one-step power-iteration estimate of ``||residual||_2 ** 2``.

    The start vector is drawn from torch's global CPU RNG on every call, so
    the value varies between calls unless the RNG is seeded.
    """
    # Draw on CPU (same RNG stream as before), then move to ``device``.
    # The dtype follows the residual so that float64 models do not fail in matmul.
    b_k = torch.rand(residual.shape[1], 1, dtype=residual.dtype).to(device)
    v1 = torch.matmul(residual, b_k)
    # Clamp the normaliser: when residual @ b_k vanishes (e.g. W is already
    # orthonormal) plain division would produce 0/0 = NaN.
    v2 = v1 / torch.norm(v1, 2).clamp_min(eps)
    v3 = torch.matmul(residual, v2)
    return torch.norm(v3, 2) ** 2


def l2_reg_ortho_loss_func(mdl, device, weight=1e-2, method='risp'):
    """Orthogonality-regularization loss summed over all matrix-shaped parameters.

    Every parameter of ``mdl`` with at least 2 dimensions is reshaped to
    ``(rows, cols)`` (``rows = W.shape[0]``), and its smaller Gram matrix ``G``
    is compared with the identity. Parameters with fewer than 2 dimensions
    (e.g. biases) are skipped.

    Args:
        mdl: module whose ``parameters()`` are regularized.
        device: device on which the identity and the random start vectors are
            placed. It must match the device of the parameters.
        weight: scalar multiplier applied to the summed penalty.
        method: selects the penalty.
            ``'risp'`` gives an SRIP-style penalty: a one-step power-iteration
            estimate of the squared spectral norm of ``G - I``, using a random
            start vector.
            ``'so'`` gives the squared Frobenius norm of ``G - I``.
            ``'none'`` gives no penalty.

    Returns:
        A 0-dim tensor ``weight * sum_W penalty(W)``. It is zero when ``method``
        is ``'none'`` or when ``mdl`` has no parameter with 2+ dimensions.

    Raises:
        ValueError: if ``method`` is not ``'risp'``, ``'so'`` or ``'none'``.
    """
    if method == 'none':
        return torch.zeros((), device=device)
    if method not in ('risp', 'so'):
        raise ValueError(
            f"unknown orthogonality regularization method {method!r}; "
            "expected 'risp', 'so' or 'none'"
        )

    l2_reg = torch.zeros((), device=device)
    for W in mdl.parameters():
        if W.ndimension() < 2:
            continue
        residual = _gram_residual(W, device)
        if method == 'risp':
            l2_reg = l2_reg + _spectral_penalty(residual, device)
        else:
            # torch.norm with p=2 and no dim flattens the matrix -> Frobenius norm.
            l2_reg = l2_reg + torch.norm(residual, 2) ** 2
    # Return only after *all* parameters have been accumulated.
    return weight * l2_reg