{"title": "Extreme Classification in Log Memory using Count-Min Sketch: A Case Study of Amazon Search with 50M Products", "book": "Advances in Neural Information Processing Systems", "page_first": 13265, "page_last": 13275, "abstract": "In the last decade, it has been shown that many hard AI tasks, especially in NLP, can be naturally modeled as extreme classification problems leading to improved precision. However, such models are prohibitively expensive to train due to the memory bottleneck in the last layer. For example, a reasonable softmax layer for the dataset of interest in this paper can easily reach well beyond 100 billion parameters (> 400 GB memory). To alleviate this problem, we present Merged-Average Classifiers via Hashing (MACH), a generic $K$-classification algorithm where memory provably scales at $O(\\log K)$ without any assumption on the relation between classes. MACH is subtly a count-min sketch structure in disguise, which uses universal hashing to reduce classification with a large number of classes to a few embarrassingly parallel and independent classification tasks with a small (constant) number of classes. MACH naturally provides a technique for zero-communication model parallelism. We experiment with 6 datasets, some multiclass and some multilabel, and show consistent improvement in precision and recall metrics compared to the respective baselines. In particular, we train an end-to-end deep classifier on a private product search dataset sampled from the Amazon Search Engine with 70 million queries and 49.46 million documents. MACH outperforms, by a significant margin, the state-of-the-art extreme classification models deployed on commercial search engines: Parabel and dense embedding models. Our largest model has 6.4 billion parameters and trains in less than 35 hrs on a single p3.16x machine. Our training times are 7-10x faster, and our memory footprints are 2-4x smaller than the best baselines. 
This training time is also significantly lower than the one reported by Google's mixture of experts (MoE) language model on a comparable model size and hardware.", "full_text": "Extreme Classification in Log Memory using Count-Min Sketch: A Case Study of Amazon Search with 50M Products\n\nTharun Medini*, Electrical and Computer Engineering, Rice University, Houston, TX 77005, tharun.medini@rice.edu\nQixuan Huang, Computer Science, Rice University, Houston, TX 77005, qh5@rice.edu\nYiqiu Wang, Computer Science, MIT, Cambridge, MA 02139, yiqiuw@mit.edu\nVijai Mohan, Amazon Search, Palo Alto, CA 94301, vijaim@amazon.com\nAnshumali Shrivastava, Computer Science, Rice University, Houston, TX 77005, anshumali@rice.edu\n\nAbstract\n\nIn the last decade, it has been shown that many hard AI tasks, especially in NLP, can be naturally modeled as extreme classification problems leading to improved precision. However, such models are prohibitively expensive to train due to the memory blow-up in the last layer. For example, a reasonable softmax layer for the dataset of interest in this paper can easily reach well beyond 100 billion parameters (> 400 GB memory). To alleviate this problem, we present Merged-Average Classifiers via Hashing (MACH), a generic K-classification algorithm where memory provably scales at O(log K) without any strong assumption on the classes. MACH is subtly a count-min sketch structure in disguise, which uses universal hashing to reduce classification with a large number of classes to a few embarrassingly parallel and independent classification tasks with a small (constant) number of classes. MACH naturally provides a technique for zero-communication model parallelism. We experiment with 6 datasets, some multiclass and some multilabel, and show consistent improvement over the respective state-of-the-art baselines. 
In particular, we train an end-to-end deep classifier on a private product search dataset sampled from the Amazon Search Engine with 70 million queries and 49.46 million products. MACH outperforms, by a significant margin, the state-of-the-art extreme classification models deployed on commercial search engines: Parabel and dense embedding models. Our largest model has 6.4 billion parameters and trains in less than 35 hours on a single p3.16x machine. Our training times are 7-10x faster, and our memory footprints are 2-4x smaller than the best baselines. This training time is also significantly lower than the one reported by Google's mixture of experts (MoE) language model on a comparable model size and hardware.\n\n1 Introduction\n\nThe area of extreme classification has gained significant interest in recent years [7, 19, 2]. In the last decade, it has been shown that many hard AI problems can be naturally modeled as massive multiclass or multilabel problems, leading to a drastic improvement over prior work. For example, popular NLP models predict the best word given the full context observed so far. Such models are becoming the state-of-the-art in machine translation [22], word embeddings [16], question answering, etc. For a large dataset, the vocabulary size can quickly run into billions [16]. 
Similarly, Information Retrieval, the backbone of modern Search Engines, has increasingly seen Deep Learning models being deployed in real Recommender Systems [18, 25].\n\n*Part of this work was done while interning at Amazon Search, Palo Alto, CA.\n\n33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.\n\nHowever, the scale of models required by the above tasks makes it practically impossible to train a straightforward classification model, forcing us to use embedding based solutions [17, 23, 4]. Embedding models project the inputs and the classes onto a small dimensional subspace (thereby removing the intractable last layer). But they have two main bottlenecks: 1) the optimization is done on a pairwise loss function, leading to a huge number of training instances (each input-class pair is a training instance), and negative sampling [16] adds to the problem; 2) the loss functions are based on handcrafted similarity thresholds which are not well understood. Using a standard cross-entropy based classifier instead intrinsically solves these two issues while introducing new challenges.\n\nOne of the datasets used in this paper is an aggregated and sampled product search dataset from the Amazon search engine which consists of 70 million queries and around 50 million products. Consider the popular and powerful p3.16x machine that has 8 V-100 GPUs, each with 16GB memory. Even for this machine, the maximum number of 32-bit floating point parameters that we can have is 32 billion. If we use momentum based optimizers like Adam [14], we would need 3x the parameters for training because Adam requires 2 auxiliary variables per parameter. That would technically limit our network parameter space to ≈ 10 billion. 
The simplest fully connected network with a single hidden layer of 2000 nodes (needed for good accuracy) and an output space of 50 million would require 2000 × (50 × 10^6) = 100 billion parameters, without accounting for the input-to-hidden layer weights and the data batches required for training. Such a model would need > 1.2 TB of memory for the parameters alone with the Adam optimizer.\n\nModel-Parallelism requires communication: The above requirements make it unreasonable to train an end-to-end classification model even on a powerful and expensive p3.16x machine. We would need sophisticated clusters and distributed training like [20]. Training large models in distributed computing environments is a sought-after topic in large scale learning. The parameters of giant models are required to be split across multiple nodes. However, this setup requires costly communication and synchronization between the parameter server and processing nodes in order to transfer the gradient and parameter updates. The sequential nature of gradient updates prohibits efficient sharding of the parameters across compute nodes. Ad-hoc model breaking is known to hurt accuracy.\n\nContrast with Google's MoE model: A notable work in large scale distributed deep learning with model parallelism is Google's 'Sparsely-Gated Mixture of Experts' [21]. Here, the authors mix smart data and model parallelism techniques to achieve fast training times for super-huge networks (with up to 137 billion parameters). One of their tasks uses the 1 billion word language modelling dataset, which has ≈ 30 million unique sentences with a total of 793K words and is much smaller than our product search dataset. One of the configurations of 'MoE' has around 4.37 billion parameters, which is smaller than our proposed model size (we use a total of 6.4 billion, as we will see in section 4.2). Using 32 k40 GPUs, they train 10 epochs in 47 hrs. 
Our model trains 10 epochs on 8 V100 GPUs (roughly similar computing power) in just 34.2 hrs. This signifies the impact of zero-communication distributed parallelism for training outrageously large networks, which to the best of our knowledge is only achieved by MACH.\n\nOur Contributions: We propose a simple hashing based divide-and-conquer algorithm MACH (Merged-Average Classification via Hashing) for K-class classification, which only requires O(d log K) model size (memory) instead of the O(Kd) required by traditional linear classifiers (d is the dimension of the penultimate layer). MACH also provides computational savings by requiring only O(Bd log K + K log K) multiplications during inference (B is a constant that we'll see later) instead of O(Kd) for the last layer.\n\nFurthermore, the training process of MACH is embarrassingly parallelizable, obviating the need for any sophisticated model parallelism. We provide strong theoretical guarantees (section C in appendix) quantifying the trade-offs between computation, accuracy and memory. In particular, we show that in O(d log (K/√δ)) memory, MACH can discriminate between any pair of classes with probability 1 − δ. Our analysis provides a novel treatment for approximation in classification via a distinguishability property between any pair of classes.\n\nWe do not make any strong assumptions on the classes. Our results are for any generic K-class classification without any relations, whatsoever, between the classes. Our idea takes the existing connections between extreme classification (sparse output) and compressed sensing, pointed out in [12], to another level in order to avoid storing the costly sensing matrix. We comment on this in detail in section 3. 
Our formalism of approximate multiclass classification and its strong connections with count-min sketches [8] could be of independent interest in itself.\n\nWe experiment with the multiclass datasets ODP-105K and fine-grained ImageNet-22K; the multilabel datasets Wiki10-31K, Delicious-200K and Amazon-670K; and an Amazon Search Dataset with 70M queries and 50M products. MACH achieves 19.28% accuracy on the ODP dataset with 105K classes, which is the best reported so far on this dataset, the previous best being only 9% [9]. To achieve around 15% accuracy, the model size with MACH is merely 1.2GB, compared to around 160GB with the one-vs-all classifier that gets 9% and requires high-memory servers to process heavy models. On the 50M Search dataset, MACH outperforms the best extreme multilabel classification technique Parabel by 6% on weighted recall, and the Deep Semantic Search Model (DSSM, a dense embedding model tested online on Amazon Search) by 21% (details in section 4.2.2). We corroborate the generalization of MACH by matching the best algorithms like Parabel and DisMEC on P@1, P@3 and P@5 on the popular extreme classification datasets Amazon-670K, Delicious-200K and Wiki10-31K.\n\n2 Background\n\nWe will use the standard bracket notation for integer ranges, i.e., [l] denotes the integer set from 1 to l: [l] = {1, 2, ..., l}. We will use the standard logistic regression setting for analysis. The data D is given by D = {(x_i, y_i)}_{i=1}^N, where x_i ∈ R^d are d-dimensional features and y_i ∈ {1, 2, ..., K}, with K denoting the number of classes. We will drop the subscript i for simplicity whenever we are talking about a generic data point and only use (x, y). 
Given an x, we will denote the probability of the label y taking the value i, under the given classifier model, as p_i = Pr(y = i|x).\n\n2-Universal Hashing: A randomized function h : [l] → [B] is 2-universal if, for all i, j ∈ [l] with i ≠ j, we have the following property for any z_1, z_2 ∈ [B]:\n\nPr(h(i) = z_1 and h(j) = z_2) = 1/B^2    (1)\n\nAs shown in [5], the simplest way to create a 2-universal hashing scheme is to pick a prime number p ≥ B, sample two random numbers a, b uniformly in the range [0, p] and compute h(x) = ((ax + b) mod p) mod B.\n\nCount-Min Sketch: Count-Min Sketch is a widely used approximate counting algorithm to identify the most frequent elements in a huge stream that we do not want to store in memory.\n\nAssume that we have a stream a_1, a_2, a_3, ..., where elements may repeat. We would like to estimate how many times each distinct element has appeared in the stream. The stream could be very long and the number of distinct elements K could be large. In Count-Min Sketch [8], we basically assign O(log K) 'signatures' to each class using 2-universal hash functions. We use O(log K) different hash functions H_1, H_2, H_3, ..., H_{O(log K)}, each mapping any class i to a small range of buckets B ≪ K, i.e., H_j(i) ∈ [B]. We maintain a counting matrix C of order O(log K) × B. If we encounter class i in the stream of classes, we increment the counts in the cells C[1, H_1(i)], C[2, H_2(i)], ..., C[O(log K), H_{O(log K)}(i)]. It is easy to notice that there will be collisions of classes into these counting cells. Hence, the counts for a class in the respective cells could be over-estimates.\n\nDuring inference, we want to know the frequency of a particular element, say a_1. We simply go to all the cells where a_1 is mapped. Each cell gives an over-estimated value of the original frequency of a_1. 
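A minimal sketch of the 2-universal family described above (the helper name, prime and seeds are illustrative; we draw a ≥ 1 to avoid the degenerate constant map, a common variant of the construction in [5]):

```python
import random

def make_2universal_hash(B, p=2_147_483_647, seed=None):
    """h(x) = ((a*x + b) mod p) mod B, with p a prime >= B (2^31 - 1 here)."""
    rng = random.Random(seed)
    a = rng.randrange(1, p)  # a >= 1: avoids the constant map (illustrative choice)
    b = rng.randrange(0, p)
    return lambda x: ((a * x + b) % p) % B

h = make_2universal_hash(B=16, seed=0)
assert all(0 <= h(i) < 16 for i in range(1000))  # every input lands in a bucket
```

Independent draws of (a, b) give the independent hash functions used throughout the paper.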
To reduce the offset of estimation, the algorithm proposes to take the minimum of all the estimates as the approximate frequency, i.e., n_approx(a_1) = min(C[1, H_1(a_1)], C[2, H_2(a_1)], ..., C[O(log K), H_{O(log K)}(a_1)]). An example illustration of Count-Min Sketch is given in figure 1 in appendix.\n\n3 Our Proposal: Merged-Averaged Classifiers via Hashing (MACH)\n\nMACH randomly merges the K classes into B random meta-classes or buckets (B is a small, manageable number). We then run any off-the-shelf classifier, such as logistic regression or a deep network, on this meta-class classification problem. We then repeat the process independently R = O(log K) times, each time using an independent 2-universal hashing scheme. During prediction, we aggregate the output from each of the R classifiers to obtain the predicted class. We show that this simple scheme is theoretically sound and only needs log K memory in Theorem 2. We present information-theoretic connections of our scheme with compressed sensing and heavy hitters. Figure 1 broadly explains our idea.\n\nFormally, we use R independently chosen 2-universal hash functions h_i : [K] → [B], i ∈ {1, 2, ..., R}. Each h_i uniformly maps the K classes into one of the B buckets. B and R are our parameters that we can tune to trade accuracy with both computation and memory. B is usually a small constant like 10 or 50. Given the data {(x_i, y_i)}_{i=1}^N, it is convenient to visualize that each hash function h_j transforms the data D into D_j = {(x_i, h_j(y_i))}_{i=1}^N. We do not have to materialize the hashed class values for all the small classifiers; we can simply access the class values through h_j. We then train R classifiers, one on each of these D_j's, to get R models M_j. This concludes our training process. Note that each h_j is independent. Training the R classifiers is trivially parallelizable across R machines or GPUs.\n\nWe need a few more notations. 
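Before introducing them, the count-min machinery above can be sketched as follows (class and parameter names are illustrative; the hash family follows the ((ax + b) mod p) mod B construction, and `num_hashes` plays the role of O(log K)):

```python
import random

class CountMinSketch:
    def __init__(self, num_hashes, num_buckets, seed=0):
        rng = random.Random(seed)
        p = 2_147_483_647  # prime >= num_buckets, for ((a*x + b) mod p) mod B
        self.B = num_buckets
        self.hashes = [(rng.randrange(1, p), rng.randrange(0, p), p)
                       for _ in range(num_hashes)]
        self.counts = [[0] * num_buckets for _ in range(num_hashes)]

    def _h(self, j, x):
        a, b, p = self.hashes[j]
        return ((a * x + b) % p) % self.B

    def add(self, x):
        # increment the cell x maps to in every row of the counting matrix
        for j in range(len(self.hashes)):
            self.counts[j][self._h(j, x)] += 1

    def query(self, x):
        # each row over-estimates due to collisions; the min reduces the offset
        return min(self.counts[j][self._h(j, x)]
                   for j in range(len(self.hashes)))

cms = CountMinSketch(num_hashes=5, num_buckets=50)
for item in [1] * 100 + [2] * 30 + list(range(3, 40)):
    cms.add(item)
assert cms.query(1) >= 100  # count-min never under-estimates
```

MACH reuses exactly this structure, except the "counts" become meta-class probabilities produced by trained classifiers.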
Each meta-classifier can only classify among the merged meta-classes. Let us denote the probability of the meta-class b ∈ [B] under the jth classifier by the capitalized P^j_b. If the meta-class contains the class i, i.e. h_j(i) = b, then we can also write it as P^j_{h_j(i)}.\n\nBefore we describe the prediction phase, we have the following theorem (proved in Section C in appendix).\n\nTheorem 1\n\nE[ (B/(B−1)) ( (1/R) Σ_{j=1}^{R} P^j_{h_j(i)}(x) − 1/B ) ] = Pr(y = i | x) = p_i    (2)\n\nfor a given x.\n\nIn Theorem 1, P^j_{h_j(i)}(x) is the predicted probability of the meta-class h_j(i) under the jth model (M_j). It's easy to observe that the true probability of a class grows linearly with the sum of the individual meta-class probabilities (with a multiplier of B/(R(B−1)) and a shift of −1/(B−1)). Thus, our classification rule is given by arg max_i Pr(y = i|x) = arg max_i Σ_{j=1}^{R} P^j_{h_j(i)}(x). The pseudocode for both the training and prediction phases is given in Algorithms 1 and 2 in appendix.\n\nFigure 1: Outline of MACH. We hash each class into B bins using a 2-universal hash function. We use R different hash functions and assign different signatures to each of the K classes. We then train R independent B-class classifiers (B ≪ K).\n\nClearly, the total model size of MACH is O(RBd), to store R models of size Bd each. The prediction cost requires RBd multiplications to get the meta probabilities, followed by KR operations to compute equation 2 for each of the classes. The argmax can be calculated on the fly. Thus, the total cost of prediction is RBd + KR. Since the R models are independent, both the training and prediction phases are conducive to trivial parallelization. 
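A minimal NumPy sketch of this aggregation rule (the hash maps and meta-probabilities below are random toy stand-ins for the trained models M_j; only the gathering and the unbiased estimator follow the paper):

```python
import numpy as np

K, B, R = 1000, 10, 6
rng = np.random.default_rng(0)
# h[j, i] = h_j(i): R independent maps from the K classes to B buckets
h = rng.integers(0, B, size=(R, K))

# meta_probs[j] stands in for the jth model's B-dim softmax output on one x
meta_probs = rng.dirichlet(np.ones(B), size=R)  # each row sums to 1

# gather P^j_{h_j(i)}(x) for every class i, then apply the unbiased
# estimator (B/(B-1)) * (mean_j P^j_{h_j(i)}(x) - 1/B) from Theorem 1
gathered = meta_probs[np.arange(R)[:, None], h]  # shape (R, K)
scores = (B / (B - 1)) * (gathered.mean(axis=0) - 1.0 / B)
pred = int(np.argmax(scores))
assert 0 <= pred < K
```

Since the estimator is an increasing affine transform of the sum, its argmax coincides with arg max_i Σ_j P^j_{h_j(i)}(x), which is why the prediction rule can skip the rescaling.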
Hence, the overall inference time can be brought down to Bd + KR.\n\nTo obtain significant savings on both model size and computation, we want BR ≪ K. The subsequent discussion shows that BR ≈ O(log K) is sufficient for identifying the final class with high probability.\n\nDefinition 1 (Indistinguishable Class Pairs): Given any two classes c_1, c_2 ∈ [K], they are indistinguishable under MACH if they fall in the same meta-class for all the R hash functions, i.e., h_j(c_1) = h_j(c_2) for all j ∈ [R].\n\nOtherwise, there is at least one classifier which provides discriminating information between them. Given that the only sources of randomness are the independent 2-universal hash functions, we have the following lemma:\n\nLemma 1 Using MACH with R independent B-class classifier models, any two original classes c_1, c_2 ∈ [K] will be indistinguishable with probability at most\n\nPr(classes c_1 and c_2 are indistinguishable) ≤ (1/B)^R    (3)\n\nThere are a total of K(K−1)/2 ≤ K^2 possible pairs, and therefore, the probability that there exists at least one pair of classes which is indistinguishable under MACH is given by the union bound as\n\nPr(∃ an indistinguishable pair) ≤ K^2 (1/B)^R    (4)\n\nThus, all we need is K^2 (1/B)^R ≤ δ to ensure that there is no indistinguishable pair with probability ≥ 1 − δ. Overall, we get the following theorem:\n\nTheorem 2 For any B, R = 2 log(K/√δ) / log B guarantees that all pairs of classes c_i and c_j are distinguishable (not indistinguishable) from each other with probability greater than 1 − δ.\n\nThe extension of MACH to the multilabel setting is quite straightforward, as all we need is to change the softmax cross-entropy loss to a binary cross-entropy loss. 
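Plugging numbers into Theorem 2 illustrates how few repetitions distinguishability needs (K is taken from the 50M-product dataset; B = 10K matches the experiments in section 4.2, and δ = 0.1 is an illustrative choice; in practice the paper uses up to R = 32 for accuracy, a stronger requirement than mere distinguishability):

```python
import math

def min_repetitions(K, B, delta):
    """R = 2*log(K/sqrt(delta)) / log(B), per Theorem 2: with this many
    repetitions, all class pairs are distinguishable w.p. >= 1 - delta."""
    return math.ceil(2 * math.log(K / math.sqrt(delta)) / math.log(B))

R = min_repetitions(K=49_460_000, B=10_000, delta=0.1)
assert R == 5  # a handful of repetitions already suffices for 50M classes
# sanity check against the union bound K^2 * (1/B)^R <= delta
assert 49_460_000 ** 2 * (1 / 10_000) ** R <= 0.1
```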
The training and evaluation are similar to the multiclass classification.\n\nConnections with Count-Min Sketch: Given a data instance x, a vanilla classifier outputs the probabilities p_i, i ∈ {1, 2, ..., K}. We want to essentially compress the information of these K numbers to log K measurements. In classification, the most informative quantity is the identity of arg max p_i. If we can identify a compression scheme that can recover the high probability classes from a smaller measurement vector, we can train a small classifier to map an input to these measurements instead of the big classifier.\n\nThe foremost class of models to accomplish this task are encoder-decoder based models like Compressive Sensing [3]. The connection between compressed sensing and extreme classification was identified in prior works [12, 10]. In [12], the idea was to use a compressed sensing matrix to compress the K-dimensional binary indicator vector of the class y_i to a real number and solve a regression problem. While Compressive Sensing is theoretically very compelling, recovering the original predictions is done through iterative algorithms like Iteratively Re-weighted Least Squares (IRLS) [11], which are prohibitive for low-latency systems like online prediction. Moreover, the objective function to minimize in each iteration involves the measurement matrix A, which is by itself a huge bottleneck to keep in memory and compute with. This defeats the whole purpose of our problem, since we cannot afford an O(K × log K) matrix.\n\nWhy only Count-Min Sketch?: Imagine a set of classes {cats, dogs, cars, trucks}. Suppose we want to train a classifier that predicts a given compressed measurement of the classes: {0.6 · p(cars) + 0.4 · p(cats), 0.5 · p(dogs) + 0.5 · p(trucks)}, where p(class) denotes the probability value of a class. There is no easy way to predict this without training a regression model. 
Prior works attempt to minimize the norm between the projections of the true (0/1) K-vector and the predicted log K-vectors (as in the case of [12]). For a large K, the regression errors are likely to be very large.\n\nOn the other hand, imagine two meta-classes {[cars & trucks], [cats & dogs]}. It is easier for a model to learn to predict whether a data point belongs to 'cars & trucks', because the probability assigned to this meta-class is the sum of the original probabilities assigned to cars and trucks. By virtue of being a union of classes, a softmax loss function works very well. Thus, a subtle insight is that only a (0/1) design matrix for compressed sensing can be made to work here. This is precisely why a CM sketch is ideal.\n\nIt should be noted that another similar alternative, Count-Sketch [6], uses a {−1, 0, +1} design matrix. This formulation creates meta-classes of the type [cars & not trucks], which cannot be easily estimated.\n\n4 Experiments\n\nWe experiment with 6 datasets whose descriptions and statistics are shown in table 1 in appendix. The training details and P@1,3,5 on the 3 multilabel datasets Wiki10-31K, Delicious-200K and Amazon-670K are also discussed in section D.3 in appendix. The brief summary of the multilabel results is that MACH consistently outperforms tree-based methods like FastXML [19] and PfastreXML [13] by a noticeable margin.\n\nTable 1: Wall Clock Execution Times and accuracies for two runs of MACH on a single Titan X.\n\nDataset | (B, R) | Model Size Reduction | Training Time | Prediction Time per Query | Accuracy\nODP | (32, 25) | 125x | 7.2 hrs | 2.85 ms | 15.446%\nImagenet | (512, 20) | 2x | 23 hrs | 8.5 ms | 10.675% 
It mostly preserves the precision achieved by the best performing algorithms like Parabel [18] and DisMEC [2], and even outperforms them on half the occasions.\n\n4.1 Multiclass Classification\n\nWe use the two large public benchmark datasets ODP and ImageNet from [9].\n\nAll our multiclass experiments were performed on the same server with a GeForce GTX TITAN X, an Intel(R) Core(TM) i7-5960X 8-core CPU @ 3.00GHz and 64GB memory. We used TensorFlow [1] to train each individual model M_i and obtain the probability matrix P_i from model M_i. We use OpenCL to compute the global score matrix that encodes the scores for all classes [1, K] in the testing data, and perform an argmax to find the predicted class. Our code and scripts are hosted at the repository https://github.com/Tharun24/MACH/.\n\nFigure 2: Accuracy-resource tradeoff with MACH (bold lines) for various settings of R and B. The number of parameters is BRd, while prediction requires KR + BRd operations. All runs of MACH require less memory than OAA. The straight lines are the accuracies of OAA, LOMTree and Recall Tree (dotted lines) on the same partition, taken from [9]. LOMTree and Recall Tree use more (around twice) memory than OAA. Left: ODP Dataset. Right: Imagenet Dataset.\n\n4.1.1 Accuracy Baselines\n\nOn these large benchmarks, there are three published methods that have reported successful evaluations: 1) OAA, traditional one-vs-all classifiers, 2) LOMTree and 3) Recall Tree. The results of all these methods are taken from [9]. OAA is the standard one-vs-all classifier, whereas LOMTree and Recall Tree are tree-based methods that reduce the computational cost of prediction at the cost of increased model size. Recall Tree uses twice as much model size compared to OAA. Even LOMTree has significantly more parameters than OAA. 
Thus, our proposal MACH is the only method that reduces the model size compared to OAA.\n\n4.1.2 Results and Discussion\n\nWe run MACH on these two datasets varying B and R. We used a plain cross-entropy loss without any regularization. We plot the accuracy as a function of different values of B and R in Figure 2. We use the unbiased estimator given by Equation 2 for inference, as it is superior to other estimators (see section D.2 in appendix for a comparison with the min and median estimators).\n\nThe plots show that for the ODP dataset MACH can even surpass OAA, achieving 18% accuracy, while the best-known accuracy on this partition is only 9%. LOMTree and Recall Tree can only achieve 6-6.5% accuracy. It should be noted that with 100,000 classes, random accuracy is 10^-5. Thus, the improvements with MACH are staggering. Even with B = 32 and R = 25, we can obtain more than 15% accuracy with a 105,000/(32 × 25) ≈ 120x reduction in model size. Thus, OAA needs a 160GB model, while we only need around 1.2GB. To get the same accuracy as OAA, we only need R = 50 and B = 4, which is a 480x reduction in model size, requiring a mere 0.3GB model file.\n\nOn the ImageNet dataset, MACH can achieve around 11% accuracy, which is roughly the accuracy of LOMTree and Recall Tree, while using R = 20 and B = 512. With R = 20 and B = 512, the memory requirement is 21841/(512 × 20) ≈ 2 times less than that of OAA. On the contrary, Recall Tree and LOMTree use 2x more memory than OAA. OAA achieves the best result of 17%. With MACH, we can run at any memory budget.\n\nIn table 1, we have compiled the running times for some of the reasonable combinations, showing the training and prediction times. The prediction time includes computing the probabilities of the meta-classes, followed by sequential aggregation of the probabilities and finding the class with the maximum probability. 
The wall clock times are significantly faster than those reported by RecallTree, which is optimized for inference.\n\n4.2 Information Retrieval with 50 Million Products\n\nAfter corroborating MACH's applicability on large public extreme classification datasets, we move on to the much more challenging real Information Retrieval dataset with 50M classes to showcase the power of MACH at scale. As mentioned earlier, we use an aggregated and sub-sampled search dataset mapping queries to product purchases. Sampling statistics are hidden to respect Amazon's disclosure policies.\n\nThe dataset has 70.3 M unique queries and 49.46 M products. For every query, there is at least one purchase from the set of products. Purchases have been amalgamated from multiple categories and then uniformly sampled. The average number of products purchased per query is 2.1, and the average number of unique queries per product is 14.69.\n\nFor evaluation, we curate another 20000 unique queries with at least one purchase among the aforementioned product set. The transactions sampled for evaluation come from a time period that succeeds the duration of the training data. Hence, there is no temporal overlap between the transactions. Our goal is to measure whether our top predictions contain the true purchased products, i.e., we are interested in measuring the purchase recall.\n\nFor measuring the performance on ranking, for each of the 20000 queries, we append 'seen but not purchased' products along with the purchased products. To be precise, every query in the evaluation dataset has a list of products, a few of which have been purchased and a few others that were clicked but not purchased (called 'seen negatives'). On average, each of the 20000 queries has 14 products, of which ≈ 2 are purchased and the rest are 'seen negatives'. 
Since products that are only 'seen' are also relevant to the query, it becomes challenging for a model to selectively rank purchases higher than other related products that were not purchased. A good model should be able to identify these subtle nuances and rank purchases higher than merely 'seen' ones.\n\nEach query in the dataset has a sparse feature representation of 715K dimensions, comprising 125K frequent word unigrams, 20K frequent bigrams, 70K character trigrams and 500K reserved slots for hashing out-of-vocabulary tokens.\n\nArchitecture: Since MACH fundamentally trains many small models, an input dimension of 715K is too large. Hence, we use sklearn's murmurhash3_32 and perform feature hashing [24] to reduce the input dimension to 80K (empirically observed to have little information loss). We use a feed-forward neural network with the architecture 80K-2K-2K-B for each of the R classifiers. 2000 is the embedding dimension for a query, another 2000 is the hidden layer dimension, and the final output layer is B-dimensional; we report metrics with B = 10000 and B = 20000. For each B, we train a maximum of 32 repetitions, i.e., R = 32. We show the performance trend as R goes from 2, 4, 8, 16 to 32.\n\nMetrics: Although we pose the search problem as a multilabel classification problem, the usual precision metric is not enough to give a clear picture. In product retrieval, we have a multitude of metrics in consideration (all metrics of interest are explained in section E in appendix). Our primary metric is weighted Recall@100 (we get the top 100 predictions from our model and measure the recall), where the weights come from the number of sessions. To be precise, if a query appeared in many sessions, we prioritize the recall on those queries as opposed to queries that are infrequent/unpopular. For example, suppose we only have 2 queries q1 and q2, which appeared in n1 and n2 sessions respectively. 
The wRecall@100 is given by (recall@100(q1) * n1 + recall@100(q2) * n2)/(n1 + n2). Un-weighted recall is the simple mean of the Recall@100 of all the queries.\n\n4.2.1 Baselines\n\nParabel: A natural comparison arises with the recent algorithm 'Parabel' [18], as it has been used in the Bing Search Engine to solve a 7M class challenge. We compare our approach with the publicly available code of Parabel. We vary the number of trees among 2, 4, 8, 16 and choose the maximum number of products per leaf node to vary among 100, 1000 and 8000. We have experimented with a few configurations of Parabel and found that the configuration with 16 trees, each with 16300 nodes (setting the maximum number of classes in each leaf node to 8000), gives the best performance. In principle, the number of trees in Parabel can be perceived as the number of repetitions in MACH. Similarly, the number of nodes in each tree in Parabel is equivalent to the number of buckets B in MACH. We could not go beyond 16 trees in Parabel as the memory consumption was beyond limits (see Table 2).\n\nDeep Semantic Search Model (DSSM): We tried running the publicly available C++ code of AnnexML [23] (a graph embedding based model) by varying the embedding dimension and the number of learners. But none of the configurations could show any progress even after 5 days of training. The next best embedding model, SLEEC [4], has public MATLAB code, but it doesn't scale beyond 1 million classes (as shown in the extreme classification repository [15]).\n\nIn the wake of these scalability challenges, we chose to compare against a dense embedding model, DSSM [17], that was A/B tested online on the Amazon Search Engine. This custom model learns an embedding matrix that has a 256-dimensional dense vector for each token (tokenized into word unigrams, word bigrams and character trigrams, as mentioned earlier). This embedding matrix is shared across both queries and products. 
Given a query, we first tokenize it, perform a sparse embedding lookup from the embedding matrix and average the vectors to yield a vector representation. Similarly, given a product (in our case, we use the title of the product), we tokenize it, perform a sparse embedding lookup and average the retrieved vectors to get a dense representation. For every query, purchased products are deemed highly relevant, and these product vectors are encouraged to be 'close' to the corresponding query vectors (imposed by a loss function). In addition to purchased products, 6× as many random products are sampled per query. These random products are deemed irrelevant by a suitable loss function.

Objective Function: All the vectors are unit normalized and the cosine similarity between two vectors is optimized. For a purchased query-product pair, the objective enforces the cosine similarity to be > θp (p for purchased). For a query-product pair deemed irrelevant, the cosine similarity is enforced to be < θr (r for random).

Given the cosine similarity s between a query-product pair and a label l (indicating p or r), the overall loss function for the embedding model is loss(s, l) = 1_p(l) · min(0, s − θp)² + 1_r(l) · max(0, s − θr)², where 1_p and 1_r are indicator functions for the two labels. The thresholds used during online testing were θp = 0.9 and θr = 0.1.

4.2.2 Results

Figure 3 shows the recall@100 for R = 2, 4, 8, 16, 32 after 1, 5 and 10 epochs respectively. The dotted red/blue/green lines correspond to MACH with B = 20K and the solid red/blue/green lines correspond to B = 10K. The cyan and magenta lines correspond to the Parabel algorithm with the number of trees varying over 2, 4, 8, 16, after 1 and 5 epochs respectively.
We could not go beyond 16 trees for Parabel because the peak memory consumption during both training and testing was reaching the limits of the AWS p3.16x machine that we used (64 vCPUs, 8 NVIDIA Tesla V-100 GPUs, 480 GB memory). The yellow line corresponds to the dense embedding based model. The training times and memory consumption are given in Table 2.

Figure 3: Comparison of our proposal MACH to Parabel and the embedding model.

Since calculating all Matching and Ranking metrics with the entire product set of size 49.46M is cumbersome, we construct a representative comparison by limiting the products to just 1M. Across all the 20000 queries in the evaluation dataset, there are 32977 unique purchases. We first retain these products and then sample the remaining 967023 products randomly from the 49.46M products. We then use our model and the baseline models to obtain a 20K × 1M score matrix, and we evaluate all the metrics on this sub-sampled representative score matrix. Tables 3 and 4 show the comparison of various metrics for MACH vs Parabel vs the embedding model.

| Model | Epochs | wRecall | Total training time | Memory (Train) | Memory (Eval) | #Params |
|---|---|---|---|---|---|---|
| DSSM, dim=256 | 5 | 0.4410 | 316.6 hrs | 40 GB | 286 GB | 200 M |
| Parabel, num_trees=16 | 5 | 0.5810 | 232.4 hrs (all 16 trees in parallel) | 350 GB | 426 GB | - |
| MACH, B=10K, R=32 | 10 | 0.6419 | 31.8 hrs (all 32 repetitions in parallel) | 150 GB | 80 GB | 5.77 B |
| MACH, B=20K, R=32 | 10 | 0.6541 | 34.2 hrs (all 32 repetitions in parallel) | 180 GB | 90 GB | 6.4 B |

Table 2: Comparison of the primary metric weighted Recall@100, training time and peak memory consumption of MACH vs Parabel vs the embedding model.
We could only train 16 trees for Parabel as we reached our memory limits.

| Metric | Embedding | Parabel | MACH, B=10K, R=32 | MACH, B=20K, R=32 |
|---|---|---|---|---|
| map_weighted | 0.6335 | 0.6419 | 0.6864 | 0.7081 |
| map_unweighted | 0.5210 | 0.4802 | 0.4913 | 0.5182 |
| mrr_weighted | 0.5596 | 0.4439 | 0.5393 | 0.5307 |
| mrr_unweighted | 0.5066 | 0.4658 | 0.4765 | 0.5015 |
| ndcg_weighted | 0.7567 | 0.7792 | 0.7211 | 0.7830 |
| ndcg_unweighted | 0.6058 | 0.5925 | 0.5828 | 0.6081 |
| recall_weighted | 0.7509 | 0.8391 | 0.8344 | 0.8486 |
| recall_unweighted | 0.8968 | 0.7717 | 0.7883 | 0.8206 |

Table 3: Comparison of Matching metrics for MACH vs Parabel vs the embedding model. These metrics are computed on the representative 1M products as explained.

| Metric | Embedding | Parabel | MACH, B=10K, R=32 | MACH, B=20K, R=32 |
|---|---|---|---|---|
| ndcg_weighted | 0.7374 | 0.7456 | 0.7769 | 0.7749 |
| ndcg_unweighted | 0.6167 | 0.6076 | 0.6072 | 0.6144 |
| mrr_weighted | 0.9180 | 0.9196 | 0.9414 | 0.9419 |
| mrr_unweighted | 0.5160 | 0.5200 | 0.5200 | 0.5293 |
| mrr_most_rel_weighted | 0.5037 | 0.5091 | 0.5146 | 0.5108 |
| mrr_most_rel_unweighted | 0.4693 | 0.4671 | 0.4681 | 0.4767 |
| prec@1_weighted | 0.8788 | 0.8744 | 0.9109 | 0.9102 |
| prec@1_unweighted | 0.3573 | 0.3521 | 0.3667 | 0.3702 |
| prec@1_most_rel_weighted | 0.3776 | 0.3741 | 0.3989 | 0.3989 |
| prec@1_most_rel_unweighted | 0.3246 | 0.3221 | 0.3365 | 0.3460 |

Table 4: Comparison of Ranking metrics. These metrics are for the curated dataset where each query has purchases and 'seen negatives' as explained in 4.2. We rank purchases higher than 'seen negatives'.

4.2.3 Analysis

MACH achieves considerably superior wRecall@100 compared to Parabel and the embedding model (Table 2). MACH's training time is 7× smaller than Parabel's and 10× smaller than the embedding model's for the same number of epochs. This is expected: Parabel has a partial tree structure that cannot exploit GPUs the way MACH does, and the embedding model trains a pointwise loss for every query-product pair, unlike MACH, which trains a multilabel cross-entropy loss per query.
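As a minimal Python sketch of that pointwise objective (using the θp = 0.9 and θr = 0.1 thresholds reported above; the function name and scalar formulation are ours, for illustration, not the production implementation):

```python
def dssm_pair_loss(s, label, theta_p=0.9, theta_r=0.1):
    """Pointwise loss for ONE query-product pair with cosine similarity s.

    label is 'p' for a purchased (relevant) pair, 'r' for a random
    (irrelevant) one. Purchased pairs are pushed above theta_p and
    random pairs below theta_r; pairs already on the right side of
    their threshold incur zero loss.
    """
    if label == 'p':
        return min(0.0, s - theta_p) ** 2
    return max(0.0, s - theta_r) ** 2

# One purchase plus 6 sampled negatives => 7 separate loss terms per query,
# whereas MACH computes a single cross-entropy over B buckets per query.
losses = [dssm_pair_loss(0.5, 'p')] + [dssm_pair_loss(0.3, 'r')] * 6
```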
Since the number of query-product pairs is huge, the training time is very high. The memory footprint while training is considerably low for the embedding model because it trains just an embedding matrix. But during evaluation, the same embedding model has to load all the 256-dimensional product vectors in memory for a nearest-neighbour lookup. This causes the memory consumption to grow substantially (more concerning when GPUs are limited). Parabel has high memory consumption during both training and evaluation.

We also note that MACH consistently outperforms the other two algorithms on Matching and Ranking metrics (Tables 3 and 4). Parabel seems to be better on MRR for Matching, but the all-important recall is much lower than MACH's.

Acknowledgments

The work was supported by NSF-1652131, NSF-BIGDATA 1838177, AFOSR-YIP FA9550-18-1-0152, an Amazon Research Award, and an ONR BRC grant for Randomized Numerical Linear Algebra.

We thank Priyanka Nigam from Amazon Search for help with data pre-processing, running the embedding model baseline and getting the Matching and Ranking metrics. We also thank Choon-Hui Teo and SVN Vishwanathan for insightful discussions about MACH's connections to different Extreme Classification paradigms.

References

[1] Martín Abadi, Ashish Agarwal, Paul Barham, Eugene Brevdo, Zhifeng Chen, Craig Citro, Greg S Corrado, Andy Davis, Jeffrey Dean, Matthieu Devin, et al. TensorFlow: Large-scale machine learning on heterogeneous distributed systems. arXiv preprint arXiv:1603.04467, 2016.

[2] Rohit Babbar and Bernhard Schölkopf. DiSMEC: Distributed sparse machines for extreme multi-label classification. In Proceedings of the Tenth ACM International Conference on Web Search and Data Mining, pages 721–729. ACM, 2017.

[3] Richard G Baraniuk. Compressive sensing [lecture notes].
IEEE Signal Processing Magazine, 24(4):118–121, 2007.

[4] Kush Bhatia, Himanshu Jain, Purushottam Kar, Manik Varma, and Prateek Jain. Sparse local embeddings for extreme multi-label classification. In Advances in Neural Information Processing Systems, pages 730–738, 2015.

[5] J Lawrence Carter and Mark N Wegman. Universal classes of hash functions. In Proceedings of the Ninth Annual ACM Symposium on Theory of Computing, pages 106–112. ACM, 1977.

[6] Moses Charikar, Kevin Chen, and Martin Farach-Colton. Finding frequent items in data streams. In International Colloquium on Automata, Languages, and Programming, pages 693–703. Springer, 2002.

[7] Anna E Choromanska and John Langford. Logarithmic time online multiclass prediction. In Advances in Neural Information Processing Systems, pages 55–63, 2015.

[8] Graham Cormode and Shan Muthukrishnan. An improved data stream summary: the count-min sketch and its applications. Journal of Algorithms, 55(1):58–75, 2005.

[9] Hal Daume III, Nikos Karampatziakis, John Langford, and Paul Mineiro. Logarithmic time one-against-some. arXiv preprint arXiv:1606.04988, 2016.

[10] Thomas G Dietterich and Ghulum Bakiri. Solving multiclass learning problems via error-correcting output codes. Journal of Artificial Intelligence Research, 2:263–286, 1995.

[11] Peter J Green. Iteratively reweighted least squares for maximum likelihood estimation, and some robust and resistant alternatives. Journal of the Royal Statistical Society, Series B (Methodological), pages 149–192, 1984.

[12] Daniel J Hsu, Sham M Kakade, John Langford, and Tong Zhang. Multi-label prediction via compressed sensing. In Advances in Neural Information Processing Systems, pages 772–780, 2009.

[13] Himanshu Jain, Yashoteja Prabhu, and Manik Varma. Extreme multi-label loss functions for recommendation, tagging, ranking & other missing label applications.
In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 935–944. ACM, 2016.

[14] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.

[15] Kush Bhatia, Kunal Dahiya, Himanshu Jain, Yashoteja Prabhu, and Manik Varma. The extreme classification repository: Multi-label datasets & code. http://manikvarma.org/downloads/XC/XMLRepository.html#Prabhu14, 2014.

[16] Tomas Mikolov, Ilya Sutskever, Kai Chen, Greg S Corrado, and Jeff Dean. Distributed representations of words and phrases and their compositionality. In Advances in Neural Information Processing Systems, pages 3111–3119, 2013.

[17] Priyanka Nigam, Yiwei Song, Vijai Mohan, Vihan Lakshman, Weitan Ding, Ankit Shingavi, Choon Hui Teo, Hao Gu, and Bing Yin. Semantic product search. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 2876–2885. ACM, 2019.

[18] Yashoteja Prabhu, Anil Kag, Shrutendra Harsola, Rahul Agrawal, and Manik Varma. Parabel: Partitioned label trees for extreme classification with application to dynamic search advertising. In Proceedings of the 2018 World Wide Web Conference on World Wide Web, pages 993–1002. International World Wide Web Conferences Steering Committee, 2018.

[19] Yashoteja Prabhu and Manik Varma. FastXML: A fast, accurate and stable tree-classifier for extreme multi-label learning. In Proceedings of the 20th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 263–272. ACM, 2014.

[20] Noam Shazeer, Youlong Cheng, Niki Parmar, Dustin Tran, Ashish Vaswani, Penporn Koanantakool, Peter Hawkins, HyoukJoong Lee, Mingsheng Hong, Cliff Young, et al. Mesh-TensorFlow: Deep learning for supercomputers.
In Advances in Neural Information Processing Systems, pages 10435–10444, 2018.

[21] Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc Le, Geoffrey Hinton, and Jeff Dean. Outrageously large neural networks: The sparsely-gated mixture-of-experts layer. arXiv preprint arXiv:1701.06538, 2017.

[22] Ilya Sutskever, Oriol Vinyals, and Quoc V Le. Sequence to sequence learning with neural networks. In Advances in Neural Information Processing Systems, pages 3104–3112, 2014.

[23] Yukihiro Tagami. AnnexML: Approximate nearest neighbor search for extreme multi-label classification. In Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 455–464. ACM, 2017.

[24] Kilian Weinberger, Anirban Dasgupta, John Langford, Alex Smola, and Josh Attenberg. Feature hashing for large scale multitask learning. In Proceedings of the 26th Annual International Conference on Machine Learning, pages 1113–1120. ACM, 2009.

[25] Rex Ying, Ruining He, Kaifeng Chen, Pong Eksombatchai, William L Hamilton, and Jure Leskovec. Graph convolutional neural networks for web-scale recommender systems. In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pages 974–983. ACM, 2018.