sql
copied
4 changed files with 105 additions and 0 deletions
@ -1,2 +1,68 @@ |
|||||
# sql |
# sql |
||||
|
|
||||
|
*author: junjie.jiang* |
||||
|
|
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
|
||||
|
## Description |
||||
|
|
||||
|
Read data from sqlite or mysql. |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
|
||||
|
## Code Example |
||||
|
|
||||
|
### Example |
||||
|
|
||||
|
```python |
||||
|
from towhee import DataLoader, pipe, ops |
||||
|
p = ( |
||||
|
pipe.input('image_path') |
||||
|
.map('image_path', 'image', ops.image_decode.cv2()) |
||||
|
.map('image', 'vec', ops.image_embedding.timm(model_name='resnet50')) |
||||
|
.output('vec') |
||||
|
|
||||
|
) |
||||
|
|
||||
|
for data in DataLoader(ops.data_source.sql('sqlite:///./sqlite.db', 'image_table')): |
||||
|
print(p(data).to_list(kv_format=True)) |
||||
|
|
||||
|
# batch |
||||
|
for data in DataLoader(ops.data_source.sql('sqlite:///./sqlite.db', 'image_table'), batch_size=10): |
||||
|
p.batch(data) |
||||
|
``` |
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
|
||||
|
***sql_url:*** *str* |
||||
|
|
||||
|
The URL of the SQL database to read from, in the form `<db_type>+<db_driver>://<username>:<password>@<host>:<port>/<database>`, for example: |
||||
|
|
||||
|
sqlite: sqlite:///./sqlite.db |
||||
|
|
||||
|
mysql: mysql+pymysql://root:123456@127.0.0.1:3306/mysql |
||||
|
|
||||
|
|
||||
|
***table_name:*** *str* |
||||
|
|
||||
|
table name |
||||
|
|
||||
|
***cols:*** *str* |
||||
|
|
||||
|
The columns to query. Defaults to `*`, which selects all columns. |
||||
|
|
||||
|
If you want to query specific columns, use the column names and separate them with `,`, such as 'id,image_path,label' |
||||
|
|
||||
|
|
||||
|
***where:*** *str* |
||||
|
|
||||
|
Where conditional statement, for example: id > 100 |
||||
|
|
||||
|
***limit:*** *int* |
||||
|
|
||||
|
The default value is 500. If set to None, all data will be returned. |
||||
|
|
||||
|
@ -0,0 +1,4 @@ |
|||||
|
from .sql_storage import SqlStorage |
||||
|
|
||||
|
def sql(*args, **kwargs):
    """Build a ``SqlStorage`` operator, forwarding every argument unchanged."""
    operator = SqlStorage(*args, **kwargs)
    return operator
@ -0,0 +1 @@ |
|||||
|
sqlalchemy |
@ -0,0 +1,34 @@ |
|||||
|
|
||||
|
from sqlalchemy import create_engine |
||||
|
|
||||
|
from towhee.operator import PyOperator |
||||
|
|
||||
|
|
||||
|
class SqlStorage(PyOperator):
    """
    Read rows from a SQL table through SQLAlchemy.

    The statement is assembled once in the constructor as
    ``SELECT <cols> FROM <table_name> [WHERE <where>] [LIMIT <limit>]`` and is
    executed each time the operator is called.

    NOTE(review): ``table_name``, ``cols`` and ``where`` are interpolated into
    the statement verbatim. They must come from trusted configuration, never
    from end-user input.

    NOTE(review): the statement is executed through ``sqlalchemy.text``, which
    reads ``:name`` as a bind parameter; a literal colon of that shape inside
    ``where`` presumably has to be escaped as ``\\:`` - confirm if needed.

    NOTE(review): a trailing ``LIMIT`` clause is not accepted by every backend
    SQLAlchemy can talk to - verify before using ``limit`` on such a backend.

    Args:
        sql_url: the SQLAlchemy URL of the database to read from, such as
            '<db_type>+<db_driver>://<username>:<password>@<host>:<port>/<database>',
            'sqlite:///./sqlite.db' for 'sqlite',
            'mysql+pymysql://root:123456@127.0.0.1:3306/mysql' for 'mysql'.
        table_name: the name of the table to query.
        cols: the columns to select, as a comma-separated string such as
            'id,image_path,label'; defaults to '*' (all columns).
        where: optional condition placed after ``WHERE``, for example
            'id > 100'; a falsy value adds no ``WHERE`` clause.
        limit: maximum number of rows to return, 500 by default; ``None``
            (or any other falsy value, e.g. 0) adds no ``LIMIT`` clause.

    Raises:
        ValueError, TypeError: if ``limit`` is truthy but cannot be converted
            to an integer.
    """

    def __init__(self, sql_url: str, table_name: str, cols: str = '*', where: str = None, limit: int = 500):
        # Let the operator base class set up its own state first.
        super().__init__()
        self._sql_url = sql_url
        self._sql = "SELECT {} FROM {}".format(cols, table_name)
        if where:
            self._sql = "{} WHERE {}".format(self._sql, where)
        if limit:
            # int() guarantees that only a number ever reaches the LIMIT clause.
            self._sql = "{} LIMIT {}".format(self._sql, int(limit))

    def __call__(self):
        """Run the prepared query and yield the result rows one at a time."""
        # Imported here so the fix does not depend on the file-level imports.
        from sqlalchemy import text

        engine = create_engine(self._sql_url)
        try:
            with engine.connect() as connection:
                # SQLAlchemy 2.x refuses plain strings; textual SQL must be
                # wrapped in text(), which 1.x accepts as well.
                yield from connection.execute(text(self._sql))
        finally:
            # The engine is created per call, so release its connection pool
            # once the rows are consumed or the generator is closed.
            engine.dispose()
||||
|
|
||||
|
|
||||
|
|
Loading…
Reference in new issue