Differentiable Decision Tree via "ReLU+Argmin" Reformulation

Qiangqiang Mao, Jiayang Ren, Yixiu Wang, Chenxuanyin Zou, Jingjing Zheng, Yankai Cao

Advances in Neural Information Processing Systems 38 (NeurIPS 2025) Main Conference Track

Decision trees, despite their unmatched interpretability and lightweight structure, face two key issues that limit their broader applicability: non-differentiability and low test accuracy. This study addresses both issues by developing a differentiable oblique tree whose entire structure is trained with gradient-based optimization. We propose an exact reformulation of hard-split trees based on a "ReLU+Argmin" mechanism, and then cast training of the reformulated tree as an unconstrained optimization task. The ReLU-based sample branching, expressed as exact-zero or non-zero values, preserves a unique decision path, in contrast to soft decision trees with probabilistic routing. The subsequent Argmin operation identifies the unique zero-violation path, enabling deterministic predictions. For effective gradient flow, we approximate the Argmin behavior with a scaled softmin function. To ameliorate numerical instability, we propose a warm-start annealing scheme that solves a sequence of optimization tasks with increasingly accurate approximations. This reformulation, together with distributed GPU parallelism, offers strong scalability, supporting trees of depth 12 even on million-scale datasets where most baselines fail. Extensive experiments demonstrate that our optimized tree achieves superior test accuracy over 14 baselines, including an average improvement of 7.54% over CART.
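To make the "ReLU+Argmin" idea concrete, the following is a minimal sketch for a depth-2 oblique tree in plain Python. It is an illustration under stated assumptions, not the formulation from the paper: the function names (`leaf_violations`, `hard_predict`, `soft_predict`), the convention that a sample goes left when w·x + b <= 0, the use of a summed path violation, and the scale parameter `alpha` are all hypothetical choices made here for exposition.

```python
import math

def relu(z):
    return max(0.0, z)

def leaf_violations(x, nodes):
    """Per-leaf path violation for a depth-2 oblique tree (illustrative).

    nodes: three (w, b) pairs for the root, its left child, and its right child.
    Assumed convention: a sample satisfies a left edge when w.x + b <= 0 and a
    right edge when w.x + b >= 0 (boundary ties are not handled specially).
    """
    s = [sum(wi * xi for wi, xi in zip(w, x)) + b for w, b in nodes]
    left = [relu(v) for v in s]    # exactly zero when the left edge is satisfied
    right = [relu(-v) for v in s]  # exactly zero when the right edge is satisfied
    return [
        left[0] + left[1],    # leaf 0: root left, then left
        left[0] + right[1],   # leaf 1: root left, then right
        right[0] + left[2],   # leaf 2: root right, then left
        right[0] + right[2],  # leaf 3: root right, then right
    ]

def hard_predict(x, nodes, leaf_values):
    """Argmin over violations selects the zero-violation leaf."""
    v = leaf_violations(x, nodes)
    return leaf_values[v.index(min(v))]

def soft_predict(x, nodes, leaf_values, alpha):
    """Scaled softmin over violations: a smooth stand-in for Argmin."""
    v = leaf_violations(x, nodes)
    m = min(v)  # shift by the minimum for numerical stability
    e = [math.exp(-alpha * (vi - m)) for vi in v]
    z = sum(e)
    return sum(ei / z * lv for ei, lv in zip(e, leaf_values))

# Hypothetical tree: split on x[0] at 0.5, then on x[1] at 0.5 in both children.
nodes = [([1.0, 0.0], -0.5), ([0.0, 1.0], -0.5), ([0.0, 1.0], -0.5)]
leaf_values = [10.0, 20.0, 30.0, 40.0]
x = [0.2, 0.9]

print(hard_predict(x, nodes, leaf_values))  # selects leaf 1
for alpha in (1.0, 10.0, 100.0, 1000.0):
    # Larger alpha pushes the softmin prediction toward the hard prediction.
    print(alpha, soft_predict(x, nodes, leaf_values, alpha))
```

In this sketch, increasing `alpha` plays the role of the "increasingly accurate approximations" mentioned above: small values give a smooth blend of leaf values, while large values approach the hard Argmin routing. A warm-start scheme would reuse the parameters found at one `alpha` as the starting point for the next, which this example does not implement.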