diff --git a/README.md b/README.md index d95077d..d32696d 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,68 @@ # sql +*author: junjie.jiang* + + +
+ + +## Description + +Read data from SQLite or MySQL. + +
class SqlStorage(PyOperator):
    """
    Data source operator that reads rows from a SQL table through SQLAlchemy.

    The query is assembled once in ``__init__`` as
    ``SELECT <cols> FROM <table_name> [WHERE <where>] [LIMIT <limit>]`` and is
    executed every time the operator is called.

    NOTE: ``table_name``, ``cols``, ``where`` and ``limit`` are interpolated
    into the SQL text verbatim (no quoting or escaping). They must come from
    trusted configuration, never from untrusted user input.

    Args:
        sql_url: the SQLAlchemy database url, in the form
            'dialect+driver://username:password@host:port/database', e.g.
            'sqlite:///./sqlite.db' for sqlite,
            'mysql+pymysql://root:123456@127.0.0.1:3306/mysql' for mysql.
        table_name: name of the table to read from.
        cols: comma separated column names to select, e.g.
            'id,image_path,label'. Defaults to '*' (all columns).
        where: optional condition placed after ``WHERE``, e.g. 'id > 100'.
            ``None`` or an empty string means no condition.
        limit: maximum number of rows to return, defaults to 500. Any falsy
            value (``None`` or ``0``) omits the ``LIMIT`` clause, so all
            matching rows are returned.
    """

    def __init__(self, sql_url: str, table_name: str, cols: str = '*', where: str = None, limit: int = 500):
        super().__init__()
        self._sql_url = sql_url

        # Collect the clauses first and join once; the resulting statement is
        # 'SELECT .. FROM .. [WHERE ..] [LIMIT ..]' separated by single spaces.
        clauses = ['SELECT {} FROM {}'.format(cols, table_name)]
        if where:
            clauses.append('WHERE {}'.format(where))
        if limit:
            clauses.append('LIMIT {}'.format(limit))
        self._sql = ' '.join(clauses)

    def __call__(self):
        """
        Execute the query and yield the result rows one by one.

        A new engine is created for every call and disposed once the
        generator is exhausted or closed.

        Yields:
            The row objects produced by iterating the SQLAlchemy result.
        """
        # Imported locally so this block does not depend on the file-level
        # import list; sqlalchemy itself is already a dependency of the file.
        from sqlalchemy import text

        engine = create_engine(self._sql_url)
        try:
            with engine.connect() as connection:
                # SQLAlchemy 2.x no longer accepts a plain ``str`` in
                # ``Connection.execute``; textual SQL must be wrapped in
                # ``text()``. ``text()`` reads ':name' tokens as bind
                # parameters, so a literal colon before a word inside
                # ``where`` has to be escaped as '\:'.
                yield from connection.execute(text(self._sql))
        finally:
            # Release the connection pool owned by this per-call engine so
            # repeated calls do not accumulate open pools.
            engine.dispose()