import functools
import hashlib
from dataclasses import dataclass
from typing import Iterable

import numpy as np
from openai import OpenAI

from redel import config
from redel.utils import batched

# On-disk location of cached embedding vectors; created at import time.
VECTOR_CACHE_DIR = config.REDEL_CACHE_DIR / "embeddings"
VECTOR_CACHE_DIR.mkdir(exist_ok=True, parents=True)


@functools.cache
def get_embedding_client() -> OpenAI:
    """Return the OpenAI client, constructing it lazily on first use.

    This is a function (rather than a module-level instance) so the client is
    only initialized if it is actually needed; ``functools.cache`` makes every
    later call return the same instance.
    """
    return OpenAI()
@dataclass
class EmbeddingResult:
    """A single embedding vector paired with the position of its input text."""

    idx: int
    """The index of the input text in its input list."""

    embedding: np.ndarray
    """The embedding returned by the server (an array of floats)."""
def get_embeddings(qs: list[str], model: str) -> list[EmbeddingResult]:
    """Get the embeddings for the inputs, caching them.

    Each text is looked up in an on-disk cache at
    ``VECTOR_CACHE_DIR / model / <sha256 hex digest of the text>.npy``. Texts
    without a readable cache file are passed to
    ``_get_embeddings_openai_batch`` and the returned vectors are saved to the
    cache. A text that occurs several times in ``qs`` is only embedded (and
    saved) once.

    Args:
        qs: The texts to embed.
        model: The embedding model name; also used as the name of the cache
            subdirectory.

    Returns:
        One ``EmbeddingResult`` per input text, sorted by ``idx`` (the text's
        position in ``qs``).

    Raises:
        AssertionError: If the number of collected results differs from
            ``len(qs)`` (e.g. the backend returned too few embeddings).
    """
    if not qs:
        return []

    # The cache directory depends only on the model, so create it once up front.
    cache_dir = VECTOR_CACHE_DIR / model
    cache_dir.mkdir(exist_ok=True, parents=True)

    result = []
    # Parallel lists describing the unique texts that still need embedding.
    uncached = []  # the texts, in first-seen order
    uncached_fps = []  # cache file path for each text
    uncached_idxs = []  # positions in `qs` that each text fills
    uncached_pos = {}  # text -> its position in the lists above

    # find cached vecs
    for idx, text in enumerate(qs):
        pos = uncached_pos.get(text)
        if pos is not None:
            # Already queued for embedding: just remember this extra position.
            uncached_idxs[pos].append(idx)
            continue

        text_hash = hashlib.sha256(text.encode()).hexdigest()
        fp = cache_dir / f"{text_hash}.npy"

        vec = None
        if fp.exists():
            try:
                vec = np.load(fp)
            except Exception:
                # Best effort: an unreadable cache file is treated as a miss;
                # it is overwritten below once the text has been re-embedded.
                vec = None

        if vec is not None:
            result.append(EmbeddingResult(idx=idx, embedding=vec))
            continue

        uncached_pos[text] = len(uncached)
        uncached.append(text)
        uncached_fps.append(fp)
        uncached_idxs.append([idx])

    # embed uncached vecs
    if uncached:
        for emb in _get_embeddings_openai_batch(uncached, model):
            # emb.idx is the position of the text within `uncached`.
            vec = emb.embedding
            np.save(uncached_fps[emb.idx], vec)
            for n, idx in enumerate(uncached_idxs[emb.idx]):
                # Duplicated inputs each get their own copy so that mutating
                # one result's vector cannot affect another's.
                own_vec = vec if n == 0 else vec.copy()
                result.append(EmbeddingResult(idx=idx, embedding=own_vec))

    assert len(result) == len(qs)
    return sorted(result, key=lambda r: r.idx)