import torch
from learning.fns import to_dense_einsum

dtype = torch.float64
# α, β, γ, δ, ε, φ, ρ = [1.0, 0., 0., 0., 1.0, 0., 0.25]
# α, β, γ, δ, ε, φ, ρ = [0.5, 0., 0.5, 0., 0.5, 0.5, 0]
# α, β, γ, δ, ε, φ, ρ = [0.5, 0.5, 0, 0.5, 0.5, 0, 0]
# α, β, γ, δ, ε, φ, ρ = [0.8, 0., 0.2, 0., 0.8, 0.2, 0]
# α, β, γ, δ, ε, φ, ρ = [0.2, 0., 0.8, 0., 0.2, 0.8, 0]
α, β, γ, δ, ε, φ, ρ = [0.6, 0.3, 0.1, 0.15, 0.7, 0.15, 0]
assert abs(α + β + γ - 1) < 1.e-8
assert abs(δ + ε + φ - 1) < 1.e-8
rank_pred = min(1, 2 + ρ - α - ε)
d = 900
exps = [α, β, γ, δ, ε, φ, ρ]
vec = [round(d**exp) for exp in exps]
α, β, γ, δ, ε, φ, ρ = vec
vec = α, β, γ, δ, ε, φ, ρ = [59, 8, 2, 3, 117, 3, 1]
rank_pred = min(d, β * γ * δ * φ * ρ)
A = torch.randn(φ, δ, ρ, γ, α)
B = torch.randn(φ, ε, ρ, γ, β)
W = to_dense_einsum(A, B)
rank = torch.linalg.matrix_rank(W)
# print(f"Rank={rank} | Shape={tuple(W.shape)} | Rank_hat={round(d**rank_pred)}")
print(f"Rank={rank} | Shape={tuple(W.shape)} | Rank_hat={rank_pred}")
