@@ -341,7 +341,62 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
341
341
}
342
342
343
343
// ----------------------------------------------------------------------------
344
- // byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt
344
+ // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
345
+
346
// Bundles everything needed to translate between token ids and strings.
// Built by build_tokenizer() from a tokenizer .bin file; released by free_tokenizer().
typedef struct {
    char **vocab;                  // vocab[i]: malloc'd, NUL-terminated string for token id i
    float *vocab_scores;           // per-token merge score, parallel to vocab
    int vocab_size;                // number of entries in vocab / vocab_scores
    unsigned int max_token_length; // read from the file header; presumably the longest token string — TODO confirm against exporter
    char byte_piece[2];            // scratch buffer used by get_piece() to return single raw-byte tokens
} Tokenizer;
353
+
354
+ void build_tokenizer (char * tokenizer , Tokenizer * t , int vocab_size ) {
355
+ // i should have written the vocab_size into the tokenizer file... sigh
356
+ t -> vocab_size = vocab_size ;
357
+ // malloc space to hold the scores and the strings
358
+ t -> vocab = (char * * )malloc (vocab_size * sizeof (char * ));
359
+ t -> vocab_scores = (float * )malloc (vocab_size * sizeof (float ));
360
+ t -> byte_piece [1 ] = '\0' ; // null terminate the byte_piece string
361
+ // read in the file
362
+ FILE * file = fopen (tokenizer , "rb" );
363
+ if (!file ) { fprintf (stderr , "couldn't load %s\n" , tokenizer ); exit (EXIT_FAILURE ); }
364
+ if (fread (& t -> max_token_length , sizeof (int ), 1 , file ) != 1 ) { fprintf (stderr , "failed read\n" ); exit (EXIT_FAILURE ); }
365
+ int len ;
366
+ for (int i = 0 ; i < vocab_size ; i ++ ) {
367
+ if (fread (t -> vocab_scores + i , sizeof (float ), 1 , file ) != 1 ) { fprintf (stderr , "failed read\n" ); exit (EXIT_FAILURE );}
368
+ if (fread (& len , sizeof (int ), 1 , file ) != 1 ) { fprintf (stderr , "failed read\n" ); exit (EXIT_FAILURE ); }
369
+ t -> vocab [i ] = (char * )malloc (len + 1 );
370
+ if (fread (t -> vocab [i ], len , 1 , file ) != 1 ) { fprintf (stderr , "failed read\n" ); exit (EXIT_FAILURE ); }
371
+ t -> vocab [i ][len ] = '\0' ; // add the string terminating token
372
+ }
373
+ fclose (file );
374
+ }
375
+
376
+ void free_tokenizer (Tokenizer * t ) {
377
+ for (int i = 0 ; i < t -> vocab_size ; i ++ ) {
378
+ free (t -> vocab [i ]);
379
+ }
380
+ free (t -> vocab );
381
+ free (t -> vocab_scores );
382
+ }
383
+
384
+ char * get_piece (Tokenizer * t , int prev_token , int token ) {
385
+ char * piece = t -> vocab [token ];
386
+ // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
387
+ if (prev_token == 1 && piece [0 ] == ' ' ) { piece ++ ; }
388
+ // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
389
+ unsigned char byte_val ;
390
+ if (sscanf (piece , "<0x%02hhX>" , & byte_val ) == 1 ) {
391
+ // ok this token is a raw byte token, careful to only print printable chars or whitespace
392
+ // some of the other bytes can be various control codes, backspace, etc. => skip
393
+ if (isprint (byte_val ) || isspace (byte_val )) {
394
+ t -> byte_piece [0 ] = byte_val ;
395
+ piece = & t -> byte_piece [0 ];
396
+ }
397
+ }
398
+ return piece ;
399
+ }
345
400
346
401
typedef struct {
347
402
char * str ;
@@ -359,22 +414,23 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
359
414
return res != NULL ? res -> id : -1 ;
360
415
}
361
416
362
- void bpe_encode (char * text , char * * vocab , float * vocab_scores , int vocab_size , unsigned int max_token_length , int * tokens , int * n_tokens ) {
417
+ void bpe_encode (Tokenizer * t , char * text , int * tokens , int * n_tokens ) {
418
+ // encode the string text (input) into an upper-bound preallocated tokens[] array
363
419
364
420
// sort vocabulary
365
- TokenIndex * sorted_vocab = malloc (vocab_size * sizeof (TokenIndex ));
366
- for (int i = 0 ; i < vocab_size ; i ++ ) {
367
- sorted_vocab [i ].str = vocab [i ];
421
+ TokenIndex * sorted_vocab = malloc (t -> vocab_size * sizeof (TokenIndex ));
422
+ for (int i = 0 ; i < t -> vocab_size ; i ++ ) {
423
+ sorted_vocab [i ].str = t -> vocab [i ];
368
424
sorted_vocab [i ].id = i ;
369
425
}
370
- qsort (sorted_vocab , vocab_size , sizeof (TokenIndex ), compare_tokens );
426
+ qsort (sorted_vocab , t -> vocab_size , sizeof (TokenIndex ), compare_tokens );
371
427
372
428
// create a temporary buffer that will store merge candidates of always two consecutive tokens
373
- char * str_buffer = malloc ((max_token_length * 2 + 1 + 2 ) * sizeof (char )); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
429
+ char * str_buffer = malloc ((t -> max_token_length * 2 + 1 + 2 ) * sizeof (char )); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1)
374
430
size_t str_len = 0 ;
375
431
376
432
// add_dummy_prefix is true by default
377
- tokens [0 ] = str_lookup (" " , sorted_vocab , vocab_size );
433
+ tokens [0 ] = str_lookup (" " , sorted_vocab , t -> vocab_size );
378
434
* n_tokens = 1 ; // the number of tokens
379
435
380
436
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
@@ -410,7 +466,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
410
466
}
411
467
412
468
// ok c+1 is not a continuation byte, so we've read in a full codepoint
413
- int id = str_lookup (str_buffer , sorted_vocab , vocab_size );
469
+ int id = str_lookup (str_buffer , sorted_vocab , t -> vocab_size );
414
470
415
471
if (id != -1 ) {
416
472
// we found this codepoint in vocab, add it as a token
@@ -434,11 +490,11 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
434
490
435
491
for (int i = 0 ; i < (* n_tokens - 1 ); i ++ ) {
436
492
// check if we can merge the pair (tokens[i], tokens[i+1])
437
- sprintf (str_buffer , "%s%s" , vocab [tokens [i ]], vocab [tokens [i + 1 ]]);
438
- int id = str_lookup (str_buffer , sorted_vocab , vocab_size );
439
- if (id != -1 && vocab_scores [id ] > best_score ) {
493
+ sprintf (str_buffer , "%s%s" , t -> vocab [tokens [i ]], t -> vocab [tokens [i + 1 ]]);
494
+ int id = str_lookup (str_buffer , sorted_vocab , t -> vocab_size );
495
+ if (id != -1 && t -> vocab_scores [id ] > best_score ) {
440
496
// this merge pair exists in vocab! record its score and position
441
- best_score = vocab_scores [id ];
497
+ best_score = t -> vocab_scores [id ];
442
498
best_id = id ;
443
499
best_idx = i ;
444
500
}
@@ -587,16 +643,16 @@ void error_usage() {
587
643
int main (int argc , char * argv []) {
588
644
589
645
// default inits
590
- char * checkpoint = NULL ; // e.g. out/model.bin
591
- char * tokenizer = "tokenizer.bin" ;
646
+ char * checkpoint_path = NULL ; // e.g. out/model.bin
647
+ char * tokenizer_path = "tokenizer.bin" ;
592
648
float temperature = 1.0f ; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
593
649
float topp = 0.9f ; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
594
650
rng_seed = 0 ; // seed rng with time by default
595
651
int steps = 256 ; // number of steps to run for
596
652
char * prompt = NULL ; // prompt string
597
653
598
654
// poor man's C argparse so we can override the defaults above from the command line
599
- if (argc >= 2 ) { checkpoint = argv [1 ]; } else { error_usage (); }
655
+ if (argc >= 2 ) { checkpoint_path = argv [1 ]; } else { error_usage (); }
600
656
for (int i = 2 ; i < argc ; i += 2 ) {
601
657
// do some basic validation
602
658
if (i + 1 >= argc ) { error_usage (); } // must have arg after flag
@@ -608,7 +664,7 @@ int main(int argc, char *argv[]) {
608
664
else if (argv [i ][1 ] == 's' ) { rng_seed = atoi (argv [i + 1 ]); }
609
665
else if (argv [i ][1 ] == 'n' ) { steps = atoi (argv [i + 1 ]); }
610
666
else if (argv [i ][1 ] == 'i' ) { prompt = argv [i + 1 ]; }
611
- else if (argv [i ][1 ] == 'z' ) { tokenizer = argv [i + 1 ]; }
667
+ else if (argv [i ][1 ] == 'z' ) { tokenizer_path = argv [i + 1 ]; }
612
668
else { error_usage (); }
613
669
}
614
670
if (rng_seed == 0 ) { rng_seed = (unsigned int )time (NULL );}
@@ -619,29 +675,14 @@ int main(int argc, char *argv[]) {
619
675
int fd = 0 ; // file descriptor for memory mapping
620
676
float * data = NULL ; // memory mapped data pointer
621
677
ssize_t file_size ; // size of the checkpoint file in bytes
622
- read_checkpoint (checkpoint , & config , & weights , & fd , & data , & file_size );
678
+ read_checkpoint (checkpoint_path , & config , & weights , & fd , & data , & file_size );
623
679
624
680
// right now we cannot run for more than config.seq_len steps
625
681
if (steps <= 0 || steps > config .seq_len ) { steps = config .seq_len ; }
626
682
627
683
// read in the tokenizer .bin file
628
- char * * vocab = (char * * )malloc (config .vocab_size * sizeof (char * ));
629
- float * vocab_scores = (float * )malloc (config .vocab_size * sizeof (float ));
630
- unsigned int max_token_length ;
631
- {
632
- FILE * file = fopen (tokenizer , "rb" );
633
- if (!file ) { fprintf (stderr , "couldn't load %s\n" , tokenizer ); return 1 ; }
634
- if (fread (& max_token_length , sizeof (int ), 1 , file ) != 1 ) { fprintf (stderr , "failed read\n" ); return 1 ; }
635
- int len ;
636
- for (int i = 0 ; i < config .vocab_size ; i ++ ) {
637
- if (fread (vocab_scores + i , sizeof (float ), 1 , file ) != 1 ) { fprintf (stderr , "failed read\n" ); return 1 ;}
638
- if (fread (& len , sizeof (int ), 1 , file ) != 1 ) { fprintf (stderr , "failed read\n" ); return 1 ; }
639
- vocab [i ] = (char * )malloc (len + 1 );
640
- if (fread (vocab [i ], len , 1 , file ) != 1 ) { fprintf (stderr , "failed read\n" ); return 1 ; }
641
- vocab [i ][len ] = '\0' ; // add the string terminating token
642
- }
643
- fclose (file );
644
- }
684
+ Tokenizer tokenizer ;
685
+ build_tokenizer (tokenizer_path , & tokenizer , config .vocab_size );
645
686
646
687
// create and init the application RunState
647
688
RunState state ;
@@ -653,7 +694,7 @@ int main(int argc, char *argv[]) {
653
694
int num_prompt_tokens = 0 ;
654
695
if (prompt != NULL ) {
655
696
prompt_tokens = (int * )malloc ((strlen (prompt )+ 1 ) * sizeof (int ));
656
- bpe_encode (prompt , vocab , vocab_scores , config . vocab_size , max_token_length , prompt_tokens , & num_prompt_tokens );
697
+ bpe_encode (& tokenizer , prompt , prompt_tokens , & num_prompt_tokens );
657
698
}
658
699
659
700
// start the main loop
@@ -695,22 +736,9 @@ int main(int argc, char *argv[]) {
695
736
// data-dependent terminating condition: the BOS (1) token delimits sequences
696
737
if (next == 1 ) { break ; }
697
738
698
- // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
699
- char * token_str = (token == 1 && vocab [next ][0 ] == ' ' ) ? vocab [next ]+ 1 : vocab [next ];
700
- // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
701
- unsigned char byte_val ;
702
- if (sscanf (token_str , "<0x%02hhX>" , & byte_val ) == 1 ) {
703
- // ok this token is a raw byte token, carefuly to only print printable chars or whitespace
704
- // some of the other bytes can be various control codes, backspace, etc. => skip
705
- if (isprint (byte_val ) || isspace (byte_val )) {
706
- char byte_piece [2 ];
707
- byte_piece [0 ] = byte_val ;
708
- byte_piece [1 ] = '\0' ;
709
- printf ("%s" , byte_piece );
710
- }
711
- } else {
712
- printf ("%s" , token_str );
713
- }
739
+ // print the token as string, decode it with the Tokenizer object
740
+ char * piece = get_piece (& tokenizer , token , next );
741
+ printf ("%s" , piece );
714
742
fflush (stdout );
715
743
token = next ;
716
744
@@ -728,9 +756,7 @@ int main(int argc, char *argv[]) {
728
756
// memory and file handles cleanup
729
757
free_run_state (& state );
730
758
free (probindex );
731
- for (int i = 0 ; i < config .vocab_size ; i ++ ) { free (vocab [i ]); }
732
- free (vocab );
733
- free (vocab_scores );
759
+ free_tokenizer (& tokenizer );
734
760
if (prompt_tokens != NULL ) free (prompt_tokens );
735
761
if (data != MAP_FAILED ) munmap (data , file_size );
736
762
if (fd != -1 ) close (fd );
0 commit comments