
import torch
from torch.utils.checkpoint import checkpoint


# Requested number of slices along the shared (inner) dimension. torch.chunk
# never returns more pieces than the dimension has elements, so with an inner
# size of 320 any value above 320 yields 320 single-column slices.
chunks = 1000

_ROWS = 15 * 64 * 64  # 61440
_INNER = 320

# X: (_ROWS, _INNER), W: (_INNER, _ROWS). Their full product is a
# _ROWS x _ROWS float16 matrix, i.e. about 7 GiB on its own.
X = torch.randn(_ROWS, _INNER, dtype=torch.float16, device="cuda")
W = torch.randn(_INNER, _ROWS, dtype=torch.float16, device="cuda")

def func(x, w):
    """Return the matrix product of ``x`` and ``w`` (equivalent to ``x @ w``)."""
    product = torch.matmul(x, w)
    return product

# res = func(X, W)  # computing the full product in one shot needs a lot of memory

# Split the contraction (inner) dimension: X @ W == sum_k X[:, k] @ W[k, :].
# Every partial product still has the full output shape, so this split bounds
# the size of each matmul's inputs, not of its output.
Xs = torch.chunk(X, chunks, dim=1)
Ws = torch.chunk(W, chunks, dim=0)

# Accumulate partial products into one running buffer. The previous version
# collected them in a list and called torch.stack(...).sum(...), which keeps
# every full-size partial alive at the same time, so peak memory grew
# linearly with the number of chunks. Here at most the accumulator plus one
# partial product are alive.
# NOTE(review): the running sum is done in float16, so rounding happens after
# every addition rather than once at the end; switch the accumulator to
# float32 if precision matters more than memory.
acc = None
for x, w in zip(Xs, Ws):
    # X and W do not require grad, so checkpoint has nothing to save or
    # recompute here; it only affects memory when gradients are tracked.
    # use_reentrant=False selects the recommended implementation (and avoids
    # the reentrant variant's "no input requires grad" warning).
    part = checkpoint(func, x, w, use_reentrant=False)
    acc = part if acc is None else acc.add_(part)
    del part  # drop this partial before the next one is allocated

# In-place exp avoids allocating a second full-size output.
# NOTE(review): exp of these float16 sums overflows to inf for entries above
# ~11.09; confirm that is acceptable for this experiment.
res = acc.exp_()

# memory_allocated() reports memory currently held by tensors, not the peak;
# use torch.cuda.max_memory_allocated() to observe the peak.
print(torch.cuda.memory_allocated())

# import torch
# from torch.utils.checkpoint import checkpoint

# n = 100
# m = 2000
# l = 150
# chunks = 20

# X = torch.randn(n, m, device="cuda").requires_grad_() * 0.01
# W = torch.randn(m, l, device="cuda").requires_grad_() * 0.01

# def func(x, w):
#     return x @ w

# # res = func(X, W)  # require large memory

# Xs = torch.chunk(X, chunks, dim=1)
# Ws = torch.chunk(W, chunks, dim=0)
# arr = []
# for x, w in zip(Xs, Ws):
#     # arr.append(func(x, w)) # also require large memory 
#     arr.append(checkpoint(func, x, w))  # checkpoint release the intermediate variables
# res = torch.stack(arr, dim=0).sum(dim=0).exp()

# print(torch.cuda.memory_allocated())