Fix handling of NULL distances in KNN-GiST
author Alexander Korotkov
Sun, 8 Sep 2019 18:13:40 +0000 (21:13 +0300)
committer Alexander Korotkov
Sun, 8 Sep 2019 18:46:58 +0000 (21:46 +0300)
In order to implement NULLS LAST semantics, GiST previously assumed the
distance to a NULL value to be Inf.  However, our distance functions can return
Inf and NaN for non-null values.  In such cases, the NULLS LAST semantics are
broken.  This commit fixes that by introducing a separate array of null flags
for distances.

Backpatch to all supported versions.

Discussion: https://postgr.es/m/CAPpHfdsNvNdA0DBS%2BwMpFrgwT6C3-q50sFVGLSiuWnV3FqOJuQ%40mail.gmail.com
Author: Alexander Korotkov
Backpatch-through: 9.4

src/backend/access/gist/gistget.c
src/backend/access/gist/gistscan.c
src/include/access/gist_private.h

index c4e8a3b9131620e0541da7067a1b13a0f3cd51f1..d86fe9c64be856a2e1513b9866151b4f8b7346bc 100644 (file)
@@ -112,8 +112,9 @@ gistkillitems(IndexScanDesc scan)
  * Similarly, *recheck_distances_p is set to indicate whether the distances
  * need to be rechecked, and it is also ignored for non-leaf entries.
  *
- * If we are doing an ordered scan, so->distances[] is filled with distance
- * data from the distance() functions before returning success.
+ * If we are doing an ordered scan, so->distanceValues[] and
+ * so->distanceNulls[] are filled with distance data from the distance()
+ * functions before returning success.
  *
  * We must decompress the key in the IndexTuple before passing it to the
  * sk_funcs (which actually are the opclass Consistent or Distance methods).
@@ -134,7 +135,8 @@ gistindex_keytest(IndexScanDesc scan,
    GISTSTATE  *giststate = so->giststate;
    ScanKey     key = scan->keyData;
    int         keySize = scan->numberOfKeys;
-   double     *distance_p;
+   double     *distance_value_p;
+   bool       *distance_null_p;
    Relation    r = scan->indexRelation;
 
    *recheck_p = false;
@@ -152,7 +154,10 @@ gistindex_keytest(IndexScanDesc scan,
        if (GistPageIsLeaf(page))   /* shouldn't happen */
            elog(ERROR, "invalid GiST tuple found on leaf page");
        for (i = 0; i < scan->numberOfOrderBys; i++)
-           so->distances[i] = -get_float8_infinity();
+       {
+           so->distanceValues[i] = -get_float8_infinity();
+           so->distanceNulls[i] = false;
+       }
        return true;
    }
 
@@ -235,7 +240,8 @@ gistindex_keytest(IndexScanDesc scan,
 
    /* OK, it passes --- now let's compute the distances */
    key = scan->orderByData;
-   distance_p = so->distances;
+   distance_value_p = so->distanceValues;
+   distance_null_p = so->distanceNulls;
    keySize = scan->numberOfOrderBys;
    while (keySize > 0)
    {
@@ -249,8 +255,9 @@ gistindex_keytest(IndexScanDesc scan,
 
        if ((key->sk_flags & SK_ISNULL) || isNull)
        {
-           /* Assume distance computes as null and sorts to the end */
-           *distance_p = get_float8_infinity();
+           /* Assume distance computes as null */
+           *distance_value_p = 0.0;
+           *distance_null_p = true;
        }
        else
        {
@@ -287,11 +294,13 @@ gistindex_keytest(IndexScanDesc scan,
                                     ObjectIdGetDatum(key->sk_subtype),
                                     PointerGetDatum(&recheck));
            *recheck_distances_p |= recheck;
-           *distance_p = DatumGetFloat8(dist);
+           *distance_value_p = DatumGetFloat8(dist);
+           *distance_null_p = false;
        }
 
        key++;
-       distance_p++;
+       distance_value_p++;
+       distance_null_p++;
        keySize--;
    }
 
@@ -304,7 +313,8 @@ gistindex_keytest(IndexScanDesc scan,
  *
  * scan: index scan we are executing
  * pageItem: search queue item identifying an index page to scan
- * myDistances: distances array associated with pageItem, or NULL at the root
+ * myDistanceValues: distances array associated with pageItem, or NULL at the root
+ * myDistanceNulls: null flags for myDistanceValues array, or NULL at the root
  * tbm: if not NULL, gistgetbitmap's output bitmap
  * ntids: if not NULL, gistgetbitmap's output tuple counter
  *
@@ -321,7 +331,8 @@ gistindex_keytest(IndexScanDesc scan,
  * sibling will be processed next.
  */
 static void
-gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem, double *myDistances,
+gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem,
+            double *myDistanceValues, bool *myDistanceNulls,
             TIDBitmap *tbm, int64 *ntids)
 {
    GISTScanOpaque so = (GISTScanOpaque) scan->opaque;
@@ -359,7 +370,7 @@ gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem, double *myDistances,
        GISTSearchItem *item;
 
        /* This can't happen when starting at the root */
-       Assert(myDistances != NULL);
+       Assert(myDistanceValues != NULL && myDistanceNulls != NULL);
 
        oldcxt = MemoryContextSwitchTo(so->queueCxt);
 
@@ -369,8 +380,10 @@ gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem, double *myDistances,
        item->data.parentlsn = pageItem->data.parentlsn;
 
        /* Insert it into the queue using same distances as for this page */
-       memcpy(item->distances, myDistances,
-              sizeof(double) * scan->numberOfOrderBys);
+       memcpy(GISTSearchItemDistanceValues(item, scan->numberOfOrderBys),
+              myDistanceValues, sizeof(double) * scan->numberOfOrderBys);
+       memcpy(GISTSearchItemDistanceNulls(item, scan->numberOfOrderBys),
+              myDistanceNulls, sizeof(bool) * scan->numberOfOrderBys);
 
        pairingheap_add(so->queue, &item->phNode);
 
@@ -465,6 +478,7 @@ gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem, double *myDistances,
             * search.
             */
            GISTSearchItem *item;
+           int         nOrderBys = scan->numberOfOrderBys;
 
            oldcxt = MemoryContextSwitchTo(so->queueCxt);
 
@@ -499,8 +513,10 @@ gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem, double *myDistances,
            }
 
            /* Insert it into the queue using new distance data */
-           memcpy(item->distances, so->distances,
-                  sizeof(double) * scan->numberOfOrderBys);
+           memcpy(GISTSearchItemDistanceValues(item, nOrderBys),
+                  so->distanceValues, sizeof(double) * nOrderBys);
+           memcpy(GISTSearchItemDistanceNulls(item, nOrderBys),
+                  so->distanceNulls, sizeof(bool) * nOrderBys);
 
            pairingheap_add(so->queue, &item->phNode);
 
@@ -555,6 +571,8 @@ getNextNearest(IndexScanDesc scan)
    do
    {
        GISTSearchItem *item = getNextGISTSearchItem(so);
+       float8 *distanceValues = GISTSearchItemDistanceValues(item, scan->numberOfOrderBys);
+       bool *distanceNulls = GISTSearchItemDistanceNulls(item, scan->numberOfOrderBys);
 
        if (!item)
            break;
@@ -574,8 +592,8 @@ getNextNearest(IndexScanDesc scan)
                    if (!scan->xs_orderbynulls[i])
                        pfree(DatumGetPointer(scan->xs_orderbyvals[i]));
 #endif
-                   scan->xs_orderbyvals[i] = Float8GetDatum(item->distances[i]);
-                   scan->xs_orderbynulls[i] = false;
+                   scan->xs_orderbyvals[i] = Float8GetDatum(distanceValues[i]);
+                   scan->xs_orderbynulls[i] = distanceNulls[i];
                }
                else if (so->orderByTypes[i] == FLOAT4OID)
                {
@@ -585,8 +603,8 @@ getNextNearest(IndexScanDesc scan)
                    if (!scan->xs_orderbynulls[i])
                        pfree(DatumGetPointer(scan->xs_orderbyvals[i]));
 #endif
-                   scan->xs_orderbyvals[i] = Float4GetDatum((float4) item->distances[i]);
-                   scan->xs_orderbynulls[i] = false;
+                   scan->xs_orderbyvals[i] = Float4GetDatum(distanceValues[i]);
+                   scan->xs_orderbynulls[i] = distanceNulls[i];
                }
                else
                {
@@ -614,7 +632,10 @@ getNextNearest(IndexScanDesc scan)
            /* visit an index page, extract its items into queue */
            CHECK_FOR_INTERRUPTS();
 
-           gistScanPage(scan, item, item->distances, NULL, NULL);
+           gistScanPage(scan, item,
+                        GISTSearchItemDistanceValues(item, scan->numberOfOrderBys),
+                        GISTSearchItemDistanceNulls(item, scan->numberOfOrderBys),
+                        NULL, NULL);
        }
 
        pfree(item);
@@ -652,7 +673,7 @@ gistgettuple(IndexScanDesc scan, ScanDirection dir)
 
        fakeItem.blkno = GIST_ROOT_BLKNO;
        memset(&fakeItem.data.parentlsn, 0, sizeof(GistNSN));
-       gistScanPage(scan, &fakeItem, NULL, NULL, NULL);
+       gistScanPage(scan, &fakeItem, NULL, NULL, NULL, NULL);
    }
 
    if (scan->numberOfOrderBys > 0)
@@ -746,7 +767,10 @@ gistgettuple(IndexScanDesc scan, ScanDirection dir)
                 * this page, we fall out of the inner "do" and loop around to
                 * return them.
                 */
-               gistScanPage(scan, item, item->distances, NULL, NULL);
+               gistScanPage(scan, item,
+                            GISTSearchItemDistanceValues(item, scan->numberOfOrderBys),
+                            GISTSearchItemDistanceNulls(item, scan->numberOfOrderBys),
+                            NULL, NULL);
 
                pfree(item);
            } while (so->nPageData == 0);
@@ -777,7 +801,7 @@ gistgetbitmap(IndexScanDesc scan, TIDBitmap *tbm)
 
    fakeItem.blkno = GIST_ROOT_BLKNO;
    memset(&fakeItem.data.parentlsn, 0, sizeof(GistNSN));
-   gistScanPage(scan, &fakeItem, NULL, tbm, &ntids);
+   gistScanPage(scan, &fakeItem, NULL, NULL, tbm, &ntids);
 
    /*
     * While scanning a leaf page, ItemPointers of matching heap tuples will
@@ -792,7 +816,10 @@ gistgetbitmap(IndexScanDesc scan, TIDBitmap *tbm)
 
        CHECK_FOR_INTERRUPTS();
 
-       gistScanPage(scan, item, item->distances, tbm, &ntids);
+       gistScanPage(scan, item,
+                    GISTSearchItemDistanceValues(item, scan->numberOfOrderBys),
+                    GISTSearchItemDistanceNulls(item, scan->numberOfOrderBys),
+                    tbm, &ntids);
 
        pfree(item);
    }
index deff8d59f7ff7f832578e1821ba51a8999070cfe..0d69df98fcd3b84fff517be8297aa9c32136f633 100644 (file)
@@ -33,14 +33,30 @@ pairingheap_GISTSearchItem_cmp(const pairingheap_node *a, const pairingheap_node
    const GISTSearchItem *sb = (const GISTSearchItem *) b;
    IndexScanDesc scan = (IndexScanDesc) arg;
    int         i;
+   double     *da = GISTSearchItemDistanceValues(sa, scan->numberOfOrderBys),
+              *db = GISTSearchItemDistanceValues(sb, scan->numberOfOrderBys);
+   bool       *na = GISTSearchItemDistanceNulls(sa, scan->numberOfOrderBys),
+              *nb = GISTSearchItemDistanceNulls(sb, scan->numberOfOrderBys);
 
    /* Order according to distance comparison */
    for (i = 0; i < scan->numberOfOrderBys; i++)
    {
-       int         cmp = -float8_cmp_internal(sa->distances[i], sb->distances[i]);
+       if (na[i])
+       {
+           if (!nb[i])
+               return -1;
+       }
+       else if (nb[i])
+       {
+           return 1;
+       }
+       else
+       {
+           int         cmp = -float8_cmp_internal(da[i], db[i]);
 
-       if (cmp != 0)
-           return cmp;
+           if (cmp != 0)
+               return cmp;
+       }
    }
 
    /* Heap items go before inner pages, to ensure a depth-first search */
@@ -84,7 +100,8 @@ gistbeginscan(Relation r, int nkeys, int norderbys)
    so->queueCxt = giststate->scanCxt;  /* see gistrescan */
 
    /* workspaces with size dependent on numberOfOrderBys: */
-   so->distances = palloc(sizeof(double) * scan->numberOfOrderBys);
+   so->distanceValues = palloc(sizeof(double) * scan->numberOfOrderBys);
+   so->distanceNulls = palloc(sizeof(bool) * scan->numberOfOrderBys);
    so->qual_ok = true;         /* in case there are zero keys */
    if (scan->numberOfOrderBys > 0)
    {
index ed0fb634a52f9e201226d0e1da1e6b74eeac7368..68faa5bc28a2196696bf83db96f321d0463ed837 100644 (file)
@@ -136,13 +136,30 @@ typedef struct GISTSearchItem
        /* we must store parentlsn to detect whether a split occurred */
        GISTSearchHeapItem heap;    /* heap info, if heap tuple */
    }           data;
-   double      distances[FLEXIBLE_ARRAY_MEMBER];   /* numberOfOrderBys
-                                                    * entries */
+
+   /*
+    * This data structure is followed by arrays of distance values and
+    * distance null flags.  Size of both arrays is
+    * IndexScanDesc->numberOfOrderBys. See macros below for accessing those
+    * arrays.
+    */
 } GISTSearchItem;
 
 #define GISTSearchItemIsHeap(item) ((item).blkno == InvalidBlockNumber)
 
-#define SizeOfGISTSearchItem(n_distances) (offsetof(GISTSearchItem, distances) + sizeof(double) * (n_distances))
+#define SizeOfGISTSearchItem(n_distances) (DOUBLEALIGN(sizeof(GISTSearchItem)) + \
+   (sizeof(double) + sizeof(bool)) * (n_distances))
+
+/*
+ * We don't actually need n_distances to compute the distance values pointer.
+ * Nevertheless, take n_distances as an argument so that the argument list of
+ * GISTSearchItemDistanceValues() matches GISTSearchItemDistanceNulls().
+ */
+#define GISTSearchItemDistanceValues(item, n_distances) \
+   ((double *) ((Pointer) (item) + DOUBLEALIGN(sizeof(GISTSearchItem))))
+
+#define GISTSearchItemDistanceNulls(item, n_distances) \
+   ((bool *) ((Pointer) (item) + DOUBLEALIGN(sizeof(GISTSearchItem)) + sizeof(double) * (n_distances)))
 
 /*
  * GISTScanOpaqueData: private state for a scan of a GiST index
@@ -158,7 +175,8 @@ typedef struct GISTScanOpaqueData
    bool        firstCall;      /* true until first gistgettuple call */
 
    /* pre-allocated workspace arrays */
-   double     *distances;      /* output area for gistindex_keytest */
+   double     *distanceValues; /* output area for gistindex_keytest */
+   bool       *distanceNulls;
 
    /* info about killed items if any (killedItems is NULL if never used) */
    OffsetNumber *killedItems;  /* offset numbers of killed items */