Replace plain-memory ordered array by binary tree in ts_stat() function.
authorTeodor Sigaev
Mon, 17 Nov 2008 12:17:09 +0000 (12:17 +0000)
committerTeodor Sigaev
Mon, 17 Nov 2008 12:17:09 +0000 (12:17 +0000)
Performance is increased from 50% up to 10^3 times depending on data.

src/backend/utils/adt/tsvector_op.c

index bc342839d99d174d90bd32ee31d0a35391818da1..5cb4f4f1d9db5c9182a35b5145b32912d69585f3 100644 (file)
@@ -7,7 +7,7 @@
  *
  *
  * IDENTIFICATION
- *   $PostgreSQL: pgsql/src/backend/utils/adt/tsvector_op.c,v 1.17 2008/11/10 21:49:16 alvherre Exp $
+ *   $PostgreSQL: pgsql/src/backend/utils/adt/tsvector_op.c,v 1.18 2008/11/17 12:17:09 teodor Exp $
  *
  *-------------------------------------------------------------------------
  */
@@ -34,34 +34,33 @@ typedef struct
    char       *operand;
 } CHKVAL;
 
-typedef struct
-{
-   uint32      cur;
-   TSVector    stat;
-} StatStorage;
 
-typedef struct
+typedef struct StatEntry
 {
-   uint32      len;
-   uint32      pos;
-   uint32      ndoc;
+   uint32      ndoc; /* zero indicates that we already was here while
+                        walking throug the tree */
    uint32      nentry;
+   struct StatEntry *left;
+   struct StatEntry *right;
+   uint32      lenlexeme;
+   char        lexeme[1];
 } StatEntry;
 
+#define STATENTRYHDRSZ (offsetof(StatEntry, lexeme))
+
 typedef struct
 {
-   int32       vl_len_;        /* varlena header (do not touch directly!) */
-   int4        size;
    int4        weight;
-   char        data[1];
-} tsstat;
 
-#define STATHDRSIZE (sizeof(int4) * 4)
-#define CALCSTATSIZE(x, lenstr) ( (x) * sizeof(StatEntry) + STATHDRSIZE + (lenstr) )
-#define STATPTR(x) ( (StatEntry*) ( (char*)(x) + STATHDRSIZE ) )
-#define STATSTRPTR(x)  ( (char*)(x) + STATHDRSIZE + ( sizeof(StatEntry) * ((TSVector)(x))->size ) )
-#define STATSTRSIZE(x) ( VARSIZE((TSVector)(x)) - STATHDRSIZE - ( sizeof(StatEntry) * ((TSVector)(x))->size ) )
+   uint32      maxdepth;
+   
+   StatEntry   **stack;
+   uint32      stackpos;
 
+   StatEntry*  root;
+} TSVectorStat;
+
+#define STATHDRSIZE (offsetof(TSVectorStat, data))
 
 static Datum tsvector_update_trigger(PG_FUNCTION_ARGS, bool config_column);
 
@@ -801,92 +800,95 @@ check_weight(TSVector txt, WordEntry *wptr, int8 weight)
    return num;
 }
 
-#define compareStatWord(a,e,s,t) \
-   tsCompareString(STATSTRPTR(s) + (a)->pos, (a)->len, \
+#define compareStatWord(a,e,t)                             \
+   tsCompareString((a)->lexeme, (a)->lenlexeme,        \
                    STRPTR(t) + (e)->pos, (e)->len,     \
                    false)
 
-typedef struct WordEntryMark
+static void
+insertStatEntry(MemoryContext persistentContext, TSVectorStat *stat, TSVector txt, uint32 off)
 {
-   WordEntry   *newentry;
-   StatEntry   *pos;
-} WordEntryMark;
+   WordEntry   *we = ARRPTR(txt) + off;
+   StatEntry   *node = stat->root, 
+               *pnode=NULL;
+   int         n,
+               res;
+   uint32      depth=1;
+
+   if (stat->weight == 0) 
+       n = (we->haspos) ? POSDATALEN(txt, we) : 1;
+   else
+       n = (we->haspos) ? check_weight(txt, we, stat->weight) : 0;
 
-static tsstat *
-formstat(tsstat *stat, TSVector txt, List *entries)
-{
-   tsstat         *newstat;
-   uint32          totallen,
-                   nentry,
-                   len = list_length(entries);
-   uint32          slen = 0;
-   WordEntry      *ptr;
-   char           *curptr;
-   StatEntry      *sptr,
-                  *nptr;
-   ListCell       *entry;
-   StatEntry      *PosSE = STATPTR(stat),
-                  *prevPosSE;
-   WordEntryMark  *mark;
-
-   foreach( entry, entries )
-   {
-       mark = (WordEntryMark*)lfirst(entry);
-       slen += mark->newentry->len;
-   }
+   if ( n == 0 )
+       return; /* nothing to insert */
 
-   nentry = stat->size + len;
-   slen += STATSTRSIZE(stat);
-   totallen = CALCSTATSIZE(nentry, slen);
-   newstat = palloc(totallen);
-   SET_VARSIZE(newstat, totallen);
-   newstat->weight = stat->weight;
-   newstat->size = nentry;
+   while( node ) 
+   {
+       res = compareStatWord(node, we, txt);
 
-   memcpy(STATSTRPTR(newstat), STATSTRPTR(stat), STATSTRSIZE(stat));
-   curptr = STATSTRPTR(newstat) + STATSTRSIZE(stat);
+       if (res == 0)
+       {
+           break;
+       }
+       else
+       {
+           pnode = node;
+           node = ( res < 0 ) ? node->left : node->right;
+       }
+       depth++;
+   }
 
-   sptr = STATPTR(stat);
-   nptr = STATPTR(newstat);
+   if (depth > stat->maxdepth)
+       stat->maxdepth = depth;
 
-   foreach(entry, entries)
+   if (node == NULL)
    {
-       prevPosSE = PosSE;
-
-       mark = (WordEntryMark*)lfirst(entry);
-       ptr  = mark->newentry;
-       PosSE = mark->pos;
-
-       /*
-        * Copy missed entries 
-        */
-       if ( PosSE > prevPosSE )
+       node = MemoryContextAlloc(persistentContext, STATENTRYHDRSZ + we->len );
+       node->left = node->right = NULL;
+       node->ndoc = 1;
+       node->nentry = n;
+       node->lenlexeme = we->len;
+       memcpy(node->lexeme, STRPTR(txt) + we->pos, node->lenlexeme);
+
+       if ( pnode==NULL )
        {
-           memcpy( nptr, prevPosSE, sizeof(StatEntry) * (PosSE-prevPosSE) );
-           nptr += PosSE-prevPosSE;
+           stat->root = node;
        }
-
-       /*
-        * Copy new entry
-        */
-       if (ptr->haspos)
-           nptr->nentry = (stat->weight) ? check_weight(txt, ptr, stat->weight) : POSDATALEN(txt, ptr);
        else
-           nptr->nentry = 1;
-       nptr->ndoc = 1;
-       nptr->len = ptr->len;
-       memcpy(curptr, STRPTR(txt) + ptr->pos, nptr->len);
-       nptr->pos = curptr - STATSTRPTR(newstat);
-       curptr += nptr->len;
-       nptr++;
-
-       pfree(mark);
+       {
+           if (res < 0)
+               pnode->left = node;
+           else
+               pnode->right = node;
+       }
+           
    }
+   else
+   {
+       node->ndoc++;
+       node->nentry += n;
+   }
+}
 
-   if ( PosSE < (StatEntry *) STATSTRPTR(stat) )
-       memcpy(nptr, PosSE, sizeof(StatEntry) * (stat->size - (PosSE - STATPTR(stat))));
-
-   return newstat;
+static void
+chooseNextStatEntry(MemoryContext persistentContext, TSVectorStat *stat, TSVector txt, 
+           uint32 low, uint32 high, uint32 offset)
+{
+   uint32      pos;
+   uint32      middle = (low + high) >> 1;
+
+   pos = (low + middle) >> 1;
+   if (low != middle && pos >= offset && pos - offset < txt->size)
+       insertStatEntry( persistentContext, stat, txt, pos - offset );
+   pos = (high + middle + 1) >> 1;
+   if (middle + 1 != high && pos >= offset && pos - offset < txt->size)
+       insertStatEntry( persistentContext, stat, txt, pos - offset );
+
+   if (low != middle)
+       chooseNextStatEntry(persistentContext, stat, txt, low, middle, offset);
+   if (high != middle + 1)
+       chooseNextStatEntry(persistentContext, stat, txt, middle + 1, high, offset);
 }
 
 /*
@@ -901,115 +903,69 @@ formstat(tsstat *stat, TSVector txt, List *entries)
  * where vector_column is a tsvector-type column in vector_table.
  */
 
-static tsstat *
-ts_accum(tsstat *stat, Datum data)
+static TSVectorStat *
+ts_accum(MemoryContext persistentContext, TSVectorStat *stat, Datum data)
 {
-   tsstat     *newstat;
-   TSVector    txt = DatumGetTSVector(data);
-   StatEntry  *sptr;
-   WordEntry  *wptr;
-   int         n = 0;
-   List       *newentries=NIL;
-   StatEntry  *StopLow;
+   TSVector        txt = DatumGetTSVector(data);
+   uint32          i,
+                   nbit = 0,
+                   offset;
 
    if (stat == NULL)
-   {                           /* Init in first */
-       stat = palloc(STATHDRSIZE);
-       SET_VARSIZE(stat, STATHDRSIZE);
-       stat->size = 0;
-       stat->weight = 0;
+   {   /* Init in first */
+       stat = MemoryContextAllocZero(persistentContext, sizeof(TSVectorStat));
+       stat->maxdepth = 1;
    }
 
    /* simple check of correctness */
    if (txt == NULL || txt->size == 0)
    {
-       if (txt != (TSVector) DatumGetPointer(data))
+       if (txt && txt != (TSVector) DatumGetPointer(data))
            pfree(txt);
        return stat;
    }
 
-   sptr = STATPTR(stat);
-   wptr = ARRPTR(txt);
-   StopLow = STATPTR(stat);
-
-   while (wptr - ARRPTR(txt) < txt->size)
-   {
-       StatEntry  *StopHigh = (StatEntry *) STATSTRPTR(stat);
-       int         cmp;
-
-       /*
-        * We do not set StopLow to begin of array because tsvector is ordered 
-        * with the sames rule, so we can search from last stopped position
-        */
-
-       while (StopLow < StopHigh)
-       {
-           sptr = StopLow + (StopHigh - StopLow) / 2;
-           cmp = compareStatWord(sptr, wptr, stat, txt);
-           if (cmp == 0)
-           {
-               if (stat->weight == 0)
-               {
-                   sptr->ndoc++;
-                   sptr->nentry += (wptr->haspos) ? POSDATALEN(txt, wptr) : 1;
-               }
-               else if (wptr->haspos && (n = check_weight(txt, wptr, stat->weight)) != 0)
-               {
-                   sptr->ndoc++;
-                   sptr->nentry += n;
-               }
-               break;
-           }
-           else if (cmp < 0)
-               StopLow = sptr + 1;
-           else
-               StopHigh = sptr;
-       }
-
-       if (StopLow >= StopHigh)
-       {                   /* not found */
-           if (stat->weight == 0 || check_weight(txt, wptr, stat->weight) != 0)
-           {
-               WordEntryMark *mark = (WordEntryMark*)palloc(sizeof(WordEntryMark));
+   i = txt->size - 1;
+   for (; i > 0; i >>= 1)
+       nbit++;
 
-               mark->newentry = wptr;
-               mark->pos = StopLow;
-               newentries = lappend( newentries, mark );
+   nbit = 1 << nbit;
+   offset = (nbit - txt->size) / 2;
 
-           }
-       }
-       wptr++;
-   }
+   insertStatEntry( persistentContext, stat, txt, (nbit >> 1) - offset );
+   chooseNextStatEntry(persistentContext, stat, txt, 0, nbit, offset);
 
-   if (list_length(newentries) == 0)
-   {                           /* no new words */
-       if (txt != (TSVector) DatumGetPointer(data))
-           pfree(txt);
-       return stat;
-   }
-
-   newstat = formstat(stat, txt, newentries);
-   list_free(newentries);
-
-   if (txt != (TSVector) DatumGetPointer(data))
-       pfree(txt);
-   return newstat;
+   return stat;
 }
 
 static void
 ts_setup_firstcall(FunctionCallInfo fcinfo, FuncCallContext *funcctx,
-                  tsstat *stat)
+                  TSVectorStat *stat)
 {
-   TupleDesc   tupdesc;
-   MemoryContext oldcontext;
-   StatStorage *st;
+   TupleDesc       tupdesc;
+   MemoryContext   oldcontext;
+   StatEntry       *node;
+
+   funcctx->user_fctx = (void *) stat;
 
    oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
-   st = palloc(sizeof(StatStorage));
-   st->cur = 0;
-   st->stat = palloc(VARSIZE(stat));
-   memcpy(st->stat, stat, VARSIZE(stat));
-   funcctx->user_fctx = (void *) st;
+
+   stat->stack = palloc0(sizeof(StatEntry *) * (stat->maxdepth + 1));
+   stat->stackpos = 0; 
+
+   node = stat->root;
+   /* find leftmost value */
+   for (;;)
+   {
+       stat->stack[ stat->stackpos ] = node;
+       if (node->left)
+       {
+           stat->stackpos++;
+           node = node->left;
+       }
+       else
+           break;
+   }
 
    tupdesc = CreateTemplateTupleDesc(3, false);
    TupleDescInitEntry(tupdesc, (AttrNumber) 1, "word",
@@ -1024,26 +980,72 @@ ts_setup_firstcall(FunctionCallInfo fcinfo, FuncCallContext *funcctx,
    MemoryContextSwitchTo(oldcontext);
 }
 
+static StatEntry *
+walkStatEntryTree(TSVectorStat *stat) 
+{
+   StatEntry   *node = stat->stack[ stat->stackpos ];
+
+   if ( node == NULL )
+       return NULL;
+
+   if ( node->ndoc != 0 )
+   {
+       /* return entry itself: we already was at left sublink */
+       return node;
+   }
+   else if (node->right && node->right != stat->stack[stat->stackpos + 1])
+   {
+       /* go on right sublink */
+       stat->stackpos++;
+       node = node->right;
+
+       /* find most-left value */
+       for (;;)
+       {
+           stat->stack[stat->stackpos] = node;
+           if (node->left)
+           {
+               stat->stackpos++;
+               node = node->left;
+           }
+           else
+               break;
+       }
+   }
+   else
+   {
+       /* we already return all left subtree, itself and  right subtree */
+       if (stat->stackpos == 0)
+           return NULL;
+
+       stat->stackpos--;
+       return walkStatEntryTree(stat);
+   }
+
+   return node;
+}
 
 static Datum
 ts_process_call(FuncCallContext *funcctx)
 {
-   StatStorage *st;
+   TSVectorStat    *st;
+   StatEntry       *entry;
+
+   st = (TSVectorStat *) funcctx->user_fctx;
 
-   st = (StatStorage *) funcctx->user_fctx;
+   entry = walkStatEntryTree(st);
 
-   if (st->cur < st->stat->size)
+   if (entry != NULL)
    {
        Datum       result;
        char       *values[3];
        char        ndoc[16];
        char        nentry[16];
-       StatEntry  *entry = STATPTR(st->stat) + st->cur;
        HeapTuple   tuple;
 
-       values[0] = palloc(entry->len + 1);
-       memcpy(values[0], STATSTRPTR(st->stat) + entry->pos, entry->len);
-       (values[0])[entry->len] = '\0';
+       values[0] = palloc(entry->lenlexeme + 1);
+       memcpy(values[0], entry->lexeme, entry->lenlexeme);
+       (values[0])[entry->lenlexeme] = '\0';
        sprintf(ndoc, "%d", entry->ndoc);
        values[1] = ndoc;
        sprintf(nentry, "%d", entry->nentry);
@@ -1053,25 +1055,22 @@ ts_process_call(FuncCallContext *funcctx)
        result = HeapTupleGetDatum(tuple);
 
        pfree(values[0]);
-       st->cur++;
+
+       /* mark entry as already visited */
+       entry->ndoc = 0;
+
        return result;
    }
-   else
-   {
-       pfree(st->stat);
-       pfree(st);
-   }
 
    return (Datum) 0;
 }
 
-static tsstat *
-ts_stat_sql(text *txt, text *ws)
+static TSVectorStat *
+ts_stat_sql(MemoryContext persistentContext, text *txt, text *ws)
 {
    char       *query = text_to_cstring(txt);
    int         i;
-   tsstat     *newstat,
-              *stat;
+   TSVectorStat *stat;
    bool        isnull;
    Portal      portal;
    SPIPlanPtr  plan;
@@ -1094,10 +1093,8 @@ ts_stat_sql(text *txt, text *ws)
                (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                 errmsg("ts_stat query must return one tsvector column")));
 
-   stat = palloc(STATHDRSIZE);
-   SET_VARSIZE(stat, STATHDRSIZE);
-   stat->size = 0;
-   stat->weight = 0;
+   stat = MemoryContextAllocZero(persistentContext, sizeof(TSVectorStat));
+   stat->maxdepth = 1;
 
    if (ws)
    {
@@ -1141,12 +1138,7 @@ ts_stat_sql(text *txt, text *ws)
            Datum       data = SPI_getbinval(SPI_tuptable->vals[i], SPI_tuptable->tupdesc, 1, &isnull);
 
            if (!isnull)
-           {
-               newstat = ts_accum(stat, data);
-               if (stat != newstat && stat)
-                   pfree(stat);
-               stat = newstat;
-           }
+               stat = ts_accum(persistentContext, stat, data);
        }
 
        SPI_freetuptable(SPI_tuptable);
@@ -1169,12 +1161,12 @@ ts_stat1(PG_FUNCTION_ARGS)
 
    if (SRF_IS_FIRSTCALL())
    {
-       tsstat     *stat;
+       TSVectorStat       *stat;
        text       *txt = PG_GETARG_TEXT_P(0);
 
        funcctx = SRF_FIRSTCALL_INIT();
        SPI_connect();
-       stat = ts_stat_sql(txt, NULL);
+       stat = ts_stat_sql(funcctx->multi_call_memory_ctx, txt, NULL);
        PG_FREE_IF_COPY(txt, 0);
        ts_setup_firstcall(fcinfo, funcctx, stat);
        SPI_finish();
@@ -1194,13 +1186,13 @@ ts_stat2(PG_FUNCTION_ARGS)
 
    if (SRF_IS_FIRSTCALL())
    {
-       tsstat     *stat;
+       TSVectorStat       *stat;
        text       *txt = PG_GETARG_TEXT_P(0);
        text       *ws = PG_GETARG_TEXT_P(1);
 
        funcctx = SRF_FIRSTCALL_INIT();
        SPI_connect();
-       stat = ts_stat_sql(txt, ws);
+       stat = ts_stat_sql(funcctx->multi_call_memory_ctx, txt, ws);
        PG_FREE_IF_COPY(txt, 0);
        PG_FREE_IF_COPY(ws, 1);
        ts_setup_firstcall(fcinfo, funcctx, stat);