2 Star 1 Fork 1

chenjun2hao/Faiss.learning

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
IndexLattice.cpp 3.47 KB
一键复制 编辑 原始数据 按行查看 历史
chenj 提交于 2020-08-11 15:45 . add commit for learning
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
// -*- c++ -*-
#include <faiss/IndexLattice.h>
#include <faiss/utils/hamming.h> // for the bitstring routines
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/distances.h>
namespace faiss {
// Builds a lattice codec for d-dimensional vectors split into nsq
// sub-vectors of dsq = d / nsq components each.
//   scale_nbit : bits used to quantize the norm of each sub-vector
//   r2         : squared radius handed to the ZnSphereCodec for each sub-vector
// Throws if d is not a multiple of nsq.
// NOTE(review): dsq = d / nsq is evaluated in the initializer list, before
// the divisibility check below, so nsq == 0 divides by zero.
IndexLattice::IndexLattice (idx_t d, int nsq, int scale_nbit, int r2):
    Index (d),
    nsq (nsq),
    dsq (d / nsq),
    zn_sphere_codec (dsq, r2),
    scale_nbit (scale_nbit)
{
    FAISS_THROW_IF_NOT (d % nsq == 0);

    // Smallest bit width whose range 2^lattice_nbit covers all nv codec
    // vectors; the loop header does all the work.
    for (lattice_nbit = 0;
         uint64_t(1) << lattice_nbit < zn_sphere_codec.nv;
         ++lattice_nbit) {
    }

    // Every sub-vector stores a norm index followed by a lattice index;
    // the whole code is rounded up to a byte boundary.
    const int bits_per_subvector = scale_nbit + lattice_nbit;
    code_size = (bits_per_subvector * nsq + 7) / 8;

    // Norm ranges are only known after train().
    is_trained = false;
}
// Learns, for each of the nsq sub-blocks, the smallest and largest L2 norm
// seen over the n training vectors (each of dimension d).
// Result layout in `trained` (2 * nsq floats): [mins(0..nsq) | maxs(0..nsq)].
// NOTE(review): with n == 0 the minima stay at +inf and the maxima become
// sqrtf(-1) = NaN; callers should pass at least one training vector.
void IndexLattice::train(idx_t n, const float* x)
{
    trained.resize (nsq * 2);
    float *lo = trained.data();
    float *hi = lo + nsq;

    for (int sq = 0; sq < nsq; sq++) {
        lo[sq] = HUGE_VAL;
        hi[sq] = -1;   // any squared norm (>= 0) beats this
    }

    // Scan squared norms; square roots are taken once at the end.
    for (idx_t i = 0; i < n; i++) {
        const float *row = x + i * d;
        for (int sq = 0; sq < nsq; sq++) {
            const float norm_sq = fvec_norm_L2sqr (row + sq * dsq, dsq);
            hi[sq] = norm_sq > hi[sq] ? norm_sq : hi[sq];
            lo[sq] = norm_sq < lo[sq] ? norm_sq : lo[sq];
        }
    }

    for (int sq = 0; sq < nsq; sq++) {
        lo[sq] = sqrtf (lo[sq]);
        hi[sq] = sqrtf (hi[sq]);
    }
    is_trained = true;
}
/* The standalone codec interface */
size_t IndexLattice::sa_code_size () const
{
return code_size;
}
// Encodes n vectors of dimension d into n codes of code_size bytes each.
// Each vector is split into nsq sub-vectors of dsq components. For each
// sub-vector j, the code stores, in this order:
//   - scale_nbit bits: its L2 norm, quantized uniformly over the trained
//     range [mins[j], maxs[j]] into 2^scale_nbit buckets (clamped);
//   - lattice_nbit bits: zn_sphere_codec.encode() of the sub-vector.
// Reads the ranges from `trained`, so train() must have been called first.
void IndexLattice::sa_encode (idx_t n, const float *x, uint8_t *codes) const
{
    const float * mins = trained.data();
    const float * maxs = mins + nsq;
    const int64_t sc = int64_t(1) << scale_nbit;   // number of norm buckets
#pragma omp parallel for
    for (idx_t i = 0; i < n; i++) {
        BitstringWriter wr(codes + i * code_size, code_size);
        const float *xi = x + i * d;
        for (int j = 0; j < nsq; j++) {
            float nj =
                (sqrtf(fvec_norm_L2sqr(xi, dsq)) - mins[j])
                * sc / (maxs[j] - mins[j]);
            // When maxs[j] == mins[j] (e.g. a single training vector, or all
            // training sub-vectors with the same norm) the expression above
            // can be 0/0 = NaN; a NaN input component also yields NaN. NaN
            // fails every ordered comparison, so a plain `nj < 0` test would
            // let it reach the int64_t cast below, which is undefined
            // behavior. `!(nj >= 0)` maps NaN (and negatives) to bucket 0.
            if (!(nj >= 0)) nj = 0;
            if (nj >= sc) nj = sc - 1;
            wr.write((int64_t)nj, scale_nbit);
            wr.write(zn_sphere_codec.encode(xi), lattice_nbit);
            xi += dsq;
        }
    }
}
// Decodes n codes (code_size bytes each) back into n vectors of dimension d.
// For each sub-vector, it reads the norm bucket index and then the lattice
// index, in that order. The lattice point produced by zn_sphere_codec.decode
// is rescaled by (bucket-center norm) / sqrt(r2).
void IndexLattice::sa_decode (idx_t n, const uint8_t *codes, float *x) const
{
    // `trained` layout: [mins | maxs], nsq entries each.
    const float *lo = trained.data();
    const float *hi = lo + nsq;
    const float nlevels = int64_t(1) << scale_nbit;
    const float radius = sqrtf(zn_sphere_codec.r2);

#pragma omp parallel for
    for (idx_t i = 0; i < n; i++) {
        BitstringReader reader(codes + i * code_size, code_size);
        for (int j = 0; j < nsq; j++) {
            float *sub = x + i * d + j * dsq;

            // Center of the stored bucket, mapped back onto [lo[j], hi[j]].
            const double bucket_center = reader.read (scale_nbit) + 0.5;
            float scale = bucket_center * (hi[j] - lo[j]) / nlevels + lo[j];
            scale /= radius;

            zn_sphere_codec.decode (reader.read (lattice_nbit), sub);
            for (size_t l = 0; l < dsq; l++) {
                sub[l] *= scale;
            }
        }
    }
}
// Not supported: this index is used through the standalone codec interface
// (sa_encode / sa_decode). Always throws.
void IndexLattice::add(idx_t /*n*/, const float* /*x*/)
{
    FAISS_THROW_MSG("not implemented");
}
// Not supported: this index is used through the standalone codec interface
// (sa_encode / sa_decode). Always throws.
void IndexLattice::search(idx_t /*n*/, const float* /*x*/, idx_t /*k*/,
                          float* /*distances*/, idx_t* /*labels*/) const
{
    FAISS_THROW_MSG("not implemented");
}
// Not supported: this index is used through the standalone codec interface
// (sa_encode / sa_decode). Always throws.
void IndexLattice::reset()
{
    FAISS_THROW_MSG("not implemented");
}
} // namespace faiss
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/chenjun2hao/Faiss.learning.git
git@gitee.com:chenjun2hao/Faiss.learning.git
chenjun2hao
Faiss.learning
Faiss.learning
master

搜索帮助