SphinxBase 5prealpha
lm_trie_quant.c
1/* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
2/* ====================================================================
3 * Copyright (c) 2015 Carnegie Mellon University. All rights
4 * reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1. Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
12 *
13 * 2. Redistributions in binary form must reproduce the above copyright
14 * notice, this list of conditions and the following disclaimer in
15 * the documentation and/or other materials provided with the
16 * distribution.
17 *
18 * This work was supported in part by funding from the Defense Advanced
19 * Research Projects Agency and the National Science Foundation of the
20 * United States of America, and the CMU Sphinx Speech Consortium.
21 *
22 * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND
23 * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
24 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
25 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
26 * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 *
34 * ====================================================================
35 *
36 */
37
38#include <math.h>
39
42#include <sphinxbase/err.h>
43
44#include "ngram_model_internal.h"
45#include "lm_trie_quant.h"
46
47#define FLOAT_INF (0x7f800000)
48
/**
 * A table of quantization centers, described as a pair of pointers.
 */
typedef struct bins_s {
    float *begin;       /* first center in the table */
    const float *end;   /* end of the table; never written through */
} bins_t;
53
55 bins_t tables[NGRAM_MAX_ORDER - 1][2];
56 bins_t *longest;
57 uint8 *mem;
58 size_t mem_size;
59 uint8 prob_bits;
60 uint8 bo_bits;
61 uint32 prob_mask;
62 uint32 bo_mask;
63};
64
65static void
66bins_create(bins_t * bins, uint8 bits, float *begin)
67{
68 bins->begin = begin;
69 bins->end = bins->begin + (1ULL << bits);
70}
71
/**
 * Binary search over the ascending range [first, last).
 *
 * @return pointer to the first element that is not less than val,
 *         or last if every element is less than val.
 */
static float *
lower_bound(float *first, const float *last, float val)
{
    int remaining = last - first;

    while (remaining > 0) {
        int half = remaining / 2;
        float *mid = first + half;

        if (*mid < val) {
            /* Everything up to and including mid is too small. */
            first = mid + 1;
            remaining -= half + 1;
        }
        else {
            remaining = half;
        }
    }
    return first;
}
93
94static uint64
95bins_encode(bins_t * bins, float value)
96{
97 float *above = lower_bound(bins->begin, bins->end, value);
98 if (above == bins->begin)
99 return 0;
100 if (above == bins->end)
101 return bins->end - bins->begin - 1;
102 return above - bins->begin - (value - *(above - 1) < *above - value);
103}
104
105static float
106bins_decode(bins_t * bins, size_t off)
107{
108 return bins->begin[off];
109}
110
/**
 * Number of bytes needed for all quantization tables of a model of the
 * given order: one probability table plus one backoff table for each of
 * the (order - 2) middle levels, and one probability table for the
 * longest level.  Unigrams are currently not quantized, so they need no
 * table.
 */
static size_t
quant_size(int order)
{
    const int prob_bits = 16;
    const int bo_bits = 16;
    size_t prob_table_bytes = (1U << prob_bits) * sizeof(float);
    size_t bo_table_bytes = (1U << bo_bits) * sizeof(float);
    size_t middle_bytes = bo_table_bytes + prob_table_bytes;

    return (order - 2) * middle_bytes + prob_table_bytes;
}
121
123lm_trie_quant_create(int order)
124{
125 float *start;
126 int i;
127 lm_trie_quant_t *quant =
128 (lm_trie_quant_t *) ckd_calloc(1, sizeof(*quant));
129 quant->mem_size = quant_size(order);
130 quant->mem =
131 (uint8 *) ckd_calloc(quant->mem_size, sizeof(*quant->mem));
132
133 quant->prob_bits = 16;
134 quant->bo_bits = 16;
135 quant->prob_mask = (1U << quant->prob_bits) - 1;
136 quant->bo_mask = (1U << quant->bo_bits) - 1;
137
138 start = (float *) (quant->mem);
139 for (i = 0; i < order - 2; i++) {
140 bins_create(&quant->tables[i][0], quant->prob_bits, start);
141 start += (1ULL << quant->prob_bits);
142 bins_create(&quant->tables[i][1], quant->bo_bits, start);
143 start += (1ULL << quant->bo_bits);
144 }
145 bins_create(&quant->tables[order - 2][0], quant->prob_bits, start);
146 quant->longest = &quant->tables[order - 2][0];
147 return quant;
148}
149
150
152lm_trie_quant_read_bin(FILE * fp, int order)
153{
154 int dummy;
155 lm_trie_quant_t *quant;
156
157 fread(&dummy, sizeof(dummy), 1, fp);
158 quant = lm_trie_quant_create(order);
159 fread(quant->mem, sizeof(*quant->mem), quant->mem_size, fp);
160
161 return quant;
162}
163
164void
165lm_trie_quant_write_bin(lm_trie_quant_t * quant, FILE * fp)
166{
167 /* Before it was quantization type */
168 int dummy = 1;
169 fwrite(&dummy, sizeof(dummy), 1, fp);
170 fwrite(quant->mem, sizeof(*quant->mem), quant->mem_size, fp);
171}
172
173void
174lm_trie_quant_free(lm_trie_quant_t * quant)
175{
176 if (quant->mem)
177 ckd_free(quant->mem);
178 ckd_free(quant);
179}
180
181uint8
182lm_trie_quant_msize(lm_trie_quant_t * quant)
183{
184 return 32;
185}
186
187uint8
188lm_trie_quant_lsize(lm_trie_quant_t * quant)
189{
190 return 16;
191}
192
/**
 * qsort() comparator ordering floats ascending.
 *
 * The values are compared directly rather than by casting their
 * difference to int: that cast truncates any difference smaller than
 * 1.0 to zero (making distinct weights compare equal) and is undefined
 * when the difference does not fit in an int.
 *
 * @return negative, zero or positive as *a is less than, equal to or
 *         greater than *b
 */
static int
weights_comparator(const void *a, const void *b)
{
    const float lhs = *(const float *) a;
    const float rhs = *(const float *) b;

    return (lhs > rhs) - (lhs < rhs);
}
198
199static void
200make_bins(float *values, uint32 values_num, float *centers, uint32 bins)
201{
202 float *finish, *start;
203 uint32 i;
204
205 qsort(values, values_num, sizeof(*values), &weights_comparator);
206 start = values;
207 for (i = 0; i < bins; i++, centers++, start = finish) {
208 finish = values + (size_t) ((uint64) values_num * (i + 1) / bins);
209 if (finish == start) {
210 /* zero length bucket. */
211 *centers = i ? *(centers - 1) : -FLOAT_INF;
212 }
213 else {
214 float sum = 0.0f;
215 float *ptr;
216 for (ptr = start; ptr != finish; ptr++) {
217 sum += *ptr;
218 }
219 *centers = sum / (float) (finish - start);
220 }
221 }
222}
223
224void
225lm_trie_quant_train(lm_trie_quant_t * quant, int order, uint32 counts,
226 ngram_raw_t * raw_ngrams)
227{
228 float *probs;
229 float *backoffs;
230 float *centers;
231 uint32 backoff_num;
232 uint32 prob_num;
233 ngram_raw_t *raw_ngrams_end;
234
235 probs = (float *) ckd_calloc(counts, sizeof(*probs));
236 backoffs = (float *) ckd_calloc(counts, sizeof(*backoffs));
237 raw_ngrams_end = raw_ngrams + counts;
238
239 for (backoff_num = 0, prob_num = 0; raw_ngrams != raw_ngrams_end;
240 raw_ngrams++) {
241 probs[prob_num++] = raw_ngrams->prob;
242 backoffs[backoff_num++] = raw_ngrams->backoff;
243 }
244
245 make_bins(probs, prob_num, quant->tables[order - 2][0].begin,
246 1ULL << quant->prob_bits);
247 centers = quant->tables[order - 2][1].begin;
248 make_bins(backoffs, backoff_num, centers, (1ULL << quant->bo_bits));
249 ckd_free(probs);
250 ckd_free(backoffs);
251}
252
253void
254lm_trie_quant_train_prob(lm_trie_quant_t * quant, int order, uint32 counts,
255 ngram_raw_t * raw_ngrams)
256{
257 float *probs;
258 uint32 prob_num;
259 ngram_raw_t *raw_ngrams_end;
260
261 probs = (float *) ckd_calloc(counts, sizeof(*probs));
262 raw_ngrams_end = raw_ngrams + counts;
263
264 for (prob_num = 0; raw_ngrams != raw_ngrams_end; raw_ngrams++) {
265 probs[prob_num++] = raw_ngrams->prob;
266 }
267
268 make_bins(probs, prob_num, quant->tables[order - 2][0].begin,
269 1ULL << quant->prob_bits);
270 ckd_free(probs);
271}
272
273void
274lm_trie_quant_mwrite(lm_trie_quant_t * quant, bitarr_address_t address,
275 int order_minus_2, float prob, float backoff)
276{
277 bitarr_write_int57(address, quant->prob_bits + quant->bo_bits,
278 (uint64) ((bins_encode
279 (&quant->tables[order_minus_2][0],
280 prob) << quant->
281 bo_bits) | bins_encode(&quant->
282 tables
283 [order_minus_2]
284 [1],
285 backoff)));
286}
287
288void
289lm_trie_quant_lwrite(lm_trie_quant_t * quant, bitarr_address_t address,
290 float prob)
291{
292 bitarr_write_int25(address, quant->prob_bits,
293 (uint32) bins_encode(quant->longest, prob));
294}
295
296float
297lm_trie_quant_mboread(lm_trie_quant_t * quant, bitarr_address_t address,
298 int order_minus_2)
299{
300 return bins_decode(&quant->tables[order_minus_2][1],
301 bitarr_read_int25(address, quant->bo_bits,
302 quant->bo_mask));
303}
304
305float
306lm_trie_quant_mpread(lm_trie_quant_t * quant, bitarr_address_t address,
307 int order_minus_2)
308{
309 address.offset += quant->bo_bits;
310 return bins_decode(&quant->tables[order_minus_2][0],
311 bitarr_read_int25(address, quant->prob_bits,
312 quant->prob_mask));
313}
314
315float
316lm_trie_quant_lpread(lm_trie_quant_t * quant, bitarr_address_t address)
317{
318 return bins_decode(quant->longest,
319 bitarr_read_int25(address, quant->prob_bits,
320 quant->prob_mask));
321}
SPHINXBASE_EXPORT void bitarr_write_int57(bitarr_address_t address, uint8 length, uint64 value)
Write specified value into bit array.
Definition bitarr.c:87
SPHINXBASE_EXPORT uint32 bitarr_read_int25(bitarr_address_t address, uint8 length, uint32 mask)
Read uint32 value from bit array.
Definition bitarr.c:100
SPHINXBASE_EXPORT void bitarr_write_int25(bitarr_address_t address, uint8 length, uint32 value)
Write specified value into bit array.
Definition bitarr.c:112
Sphinx's memory allocation/deallocation routines.
SPHINXBASE_EXPORT void ckd_free(void *ptr)
Test and free a 1-D array.
Definition ckd_alloc.c:244
#define ckd_calloc(n, sz)
Macros to simplify the use of above functions.
Definition ckd_alloc.h:248
Implementation of logging routines.
Basic type definitions used in Sphinx.
Structure that stores address of certain value in bit array.
Definition bitarr.h:73