Skip to content

Commit c74456f

Browse files
committed
refactor step 1. the tokenizer, and all the other abstractions, are a total mess, refactoring things a bit
1 parent 1e335a4 commit c74456f

File tree

1 file changed

+81
-55
lines changed

1 file changed

+81
-55
lines changed

run.c

Lines changed: 81 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,62 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
341341
}
342342

343343
// ----------------------------------------------------------------------------
344-
// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
344+
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
345+
346+
typedef struct {
    char** vocab;                  // vocab[i] is the malloc'd, NUL-terminated string for token id i
    float* vocab_scores;           // vocab_scores[i] is the merge score for token id i (used by bpe_encode)
    int vocab_size;                // number of entries in vocab / vocab_scores
    unsigned int max_token_length; // max length of any token string, read from the tokenizer file header
    char byte_piece[2];            // scratch for decoding raw-byte tokens like '<0x01>': one char + NUL terminator
} Tokenizer;
353+
354+
// Load a tokenizer from its binary file into t.
// File layout: [max_token_length : int] then, for each of vocab_size tokens,
// [score : float][len : int][len bytes of string data].
// vocab_size is passed in because it is not stored in the file.
// On any I/O, allocation, or corruption error this prints to stderr and
// exits the process. On success the caller owns the Tokenizer's heap
// memory and must release it with free_tokenizer().
void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) {
    // i should have written the vocab_size into the tokenizer file... sigh
    t->vocab_size = vocab_size;
    // malloc space to hold the scores and the strings
    t->vocab = (char**)malloc(vocab_size * sizeof(char*));
    t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
    if (!t->vocab || !t->vocab_scores) { fprintf(stderr, "malloc failed\n"); exit(EXIT_FAILURE); }
    t->byte_piece[1] = '\0'; // null terminate the byte_piece string
    // read in the file
    FILE *file = fopen(tokenizer, "rb");
    if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); exit(EXIT_FAILURE); }
    if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
    int len;
    for (int i = 0; i < vocab_size; i++) {
        if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
        if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
        // guard against a corrupt file: a negative len would make malloc(len+1)
        // and vocab[i][len] undefined behavior
        if (len < 0) { fprintf(stderr, "bad token length in %s\n", tokenizer); exit(EXIT_FAILURE); }
        t->vocab[i] = (char *)malloc(len + 1);
        if (!t->vocab[i]) { fprintf(stderr, "malloc failed\n"); exit(EXIT_FAILURE); }
        // fread with a size of 0 returns 0 (not 1) even on success, so only
        // attempt the read for non-empty token strings
        if (len > 0 && fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
        t->vocab[i][len] = '\0'; // add the string terminating token
    }
    fclose(file);
}
375+
376+
// Release all heap memory owned by the Tokenizer: every per-token string,
// the string table, and the score table. The Tokenizer struct itself is
// not freed (it may live on the stack, as in main).
void free_tokenizer(Tokenizer* t) {
    int n = t->vocab_size;
    for (int idx = 0; idx < n; idx++) { free(t->vocab[idx]); }
    free(t->vocab);
    free(t->vocab_scores);
}
383+
384+
// Map a token id to the string that should be printed for it, given the
// previously emitted token. Returns a pointer either into the vocab table
// or into t->byte_piece (valid only until the next call); the caller must
// not free it.
char* get_piece(Tokenizer* t, int prev_token, int token) {
    char *piece = t->vocab[token];
    // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
    if (prev_token == 1 && piece[0] == ' ') {
        piece += 1;
    }
    // some tokens designate raw bytes and look like e.g. '<0x01>'
    unsigned char byte_val;
    if (sscanf(piece, "<0x%02hhX>", &byte_val) != 1) {
        return piece; // ordinary string token
    }
    // raw byte token: only pass through printable chars or whitespace;
    // other bytes can be various control codes, backspace, etc. => skip
    // (i.e. fall through and return the '<0xNN>' literal unchanged)
    if (isprint(byte_val) || isspace(byte_val)) {
        t->byte_piece[0] = byte_val;
        piece = t->byte_piece;
    }
    return piece;
}
345400

346401
typedef struct {
347402
char *str;
@@ -359,22 +414,23 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
359414
return res != NULL ? res->id : -1;
360415
}
361416

362-
void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) {
417+
void bpe_encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
418+
// encode the string text (input) into an upper-bound preallocated tokens[] array
363419

364420
// sort vocabulary
365-
TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex));
366-
for (int i = 0; i < vocab_size; i++) {
367-
sorted_vocab[i].str = vocab[i];
421+
TokenIndex *sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
422+
for (int i = 0; i < t->vocab_size; i++) {
423+
sorted_vocab[i].str = t->vocab[i];
368424
sorted_vocab[i].id = i;
369425
}
370-
qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
426+
qsort(sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
371427

372428
// create a temporary buffer that will store merge candidates of always two consecutive tokens
373-
char* str_buffer = malloc((max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
429+
char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
374430
size_t str_len = 0;
375431

376432
// add_dummy_prefix is true by default
377-
tokens[0] = str_lookup(" ", sorted_vocab, vocab_size);
433+
tokens[0] = str_lookup(" ", sorted_vocab, t->vocab_size);
378434
*n_tokens = 1; // the number of tokens
379435

380436
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
@@ -410,7 +466,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
410466
}
411467

412468
// ok c+1 is not a continuation byte, so we've read in a full codepoint
413-
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
469+
int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size);
414470

415471
if (id != -1) {
416472
// we found this codepoint in vocab, add it as a token
@@ -434,11 +490,11 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
434490

435491
for (int i=0; i < (*n_tokens-1); i++) {
436492
// check if we can merge the pair (tokens[i], tokens[i+1])
437-
sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
438-
int id = str_lookup(str_buffer, sorted_vocab, vocab_size);
439-
if (id != -1 && vocab_scores[id] > best_score) {
493+
sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
494+
int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size);
495+
if (id != -1 && t->vocab_scores[id] > best_score) {
440496
// this merge pair exists in vocab! record its score and position
441-
best_score = vocab_scores[id];
497+
best_score = t->vocab_scores[id];
442498
best_id = id;
443499
best_idx = i;
444500
}
@@ -587,16 +643,16 @@ void error_usage() {
587643
int main(int argc, char *argv[]) {
588644

589645
// default inits
590-
char *checkpoint = NULL; // e.g. out/model.bin
591-
char *tokenizer = "tokenizer.bin";
646+
char *checkpoint_path = NULL; // e.g. out/model.bin
647+
char *tokenizer_path = "tokenizer.bin";
592648
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
593649
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
594650
rng_seed = 0; // seed rng with time by default
595651
int steps = 256; // number of steps to run for
596652
char *prompt = NULL; // prompt string
597653

598654
// poor man's C argparse so we can override the defaults above from the command line
599-
if (argc >= 2) { checkpoint = argv[1]; } else { error_usage(); }
655+
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
600656
for (int i = 2; i < argc; i+=2) {
601657
// do some basic validation
602658
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
@@ -608,7 +664,7 @@ int main(int argc, char *argv[]) {
608664
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
609665
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
610666
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
611-
else if (argv[i][1] == 'z') { tokenizer = argv[i + 1]; }
667+
else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
612668
else { error_usage(); }
613669
}
614670
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}
@@ -619,29 +675,14 @@ int main(int argc, char *argv[]) {
619675
int fd = 0; // file descriptor for memory mapping
620676
float* data = NULL; // memory mapped data pointer
621677
ssize_t file_size; // size of the checkpoint file in bytes
622-
read_checkpoint(checkpoint, &config, &weights, &fd, &data, &file_size);
678+
read_checkpoint(checkpoint_path, &config, &weights, &fd, &data, &file_size);
623679

624680
// right now we cannot run for more than config.seq_len steps
625681
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
626682

627683
// read in the tokenizer .bin file
628-
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
629-
float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
630-
unsigned int max_token_length;
631-
{
632-
FILE *file = fopen(tokenizer, "rb");
633-
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); return 1; }
634-
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
635-
int len;
636-
for (int i = 0; i < config.vocab_size; i++) {
637-
if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1;}
638-
if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
639-
vocab[i] = (char *)malloc(len + 1);
640-
if (fread(vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
641-
vocab[i][len] = '\0'; // add the string terminating token
642-
}
643-
fclose(file);
644-
}
684+
Tokenizer tokenizer;
685+
build_tokenizer(tokenizer_path, &tokenizer, config.vocab_size);
645686

646687
// create and init the application RunState
647688
RunState state;
@@ -653,7 +694,7 @@ int main(int argc, char *argv[]) {
653694
int num_prompt_tokens = 0;
654695
if (prompt != NULL) {
655696
prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int));
656-
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
697+
bpe_encode(&tokenizer, prompt, prompt_tokens, &num_prompt_tokens);
657698
}
658699

659700
// start the main loop
@@ -695,22 +736,9 @@ int main(int argc, char *argv[]) {
695736
// data-dependent terminating condition: the BOS (1) token delimits sequences
696737
if (next == 1) { break; }
697738

698-
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
699-
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
700-
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
701-
unsigned char byte_val;
702-
if (sscanf(token_str, "<0x%02hhX>", &byte_val) == 1) {
703-
// ok this token is a raw byte token, carefuly to only print printable chars or whitespace
704-
// some of the other bytes can be various control codes, backspace, etc. => skip
705-
if (isprint(byte_val) || isspace(byte_val)) {
706-
char byte_piece[2];
707-
byte_piece[0] = byte_val;
708-
byte_piece[1] = '\0';
709-
printf("%s", byte_piece);
710-
}
711-
} else {
712-
printf("%s", token_str);
713-
}
739+
// print the token as string, decode it with the Tokenizer object
740+
char* piece = get_piece(&tokenizer, token, next);
741+
printf("%s", piece);
714742
fflush(stdout);
715743
token = next;
716744

@@ -728,9 +756,7 @@ int main(int argc, char *argv[]) {
728756
// memory and file handles cleanup
729757
free_run_state(&state);
730758
free(probindex);
731-
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
732-
free(vocab);
733-
free(vocab_scores);
759+
free_tokenizer(&tokenizer);
734760
if (prompt_tokens != NULL) free(prompt_tokens);
735761
if (data != MAP_FAILED) munmap(data, file_size);
736762
if (fd != -1) close(fd);

0 commit comments

Comments
 (0)