diff --git a/README.md b/README.md index b87ca52..7250308 100644 --- a/README.md +++ b/README.md @@ -33,12 +33,19 @@ to generate a text embedding for the sentence "Hello, world.". *Write the pipeline*: ```python -import towhee +from towhee.dc2 import pipe, ops, DataCollection -towhee.dc(["Hello, world."]) \ - .text_embedding.dpr(model_name="facebook/dpr-ctx_encoder-single-nq-base") +p = ( + pipe.input('text') + .map('text', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base')) + .output('text', 'vec') +) + +DataCollection(p('Hello, world.')).show() ``` + +
## Factory Constructor diff --git a/dpr.py b/dpr.py index 004626a..cf84ffd 100644 --- a/dpr.py +++ b/dpr.py @@ -50,7 +50,7 @@ class Dpr(NNOperator): except Exception as e: log.error(f'Invalid input for the model: {self.model_name}') raise e - vec = embeddings.detach().numpy() + vec = embeddings.squeeze(0).detach().numpy() return vec diff --git a/result.png b/result.png new file mode 100644 index 0000000..28e3918 Binary files /dev/null and b/result.png differ