Skip to content

Commit fa69302

Browse files
authored
Merge pull request #1 from LlmKira/dev
🔨 chore(generate_image): Clean up generate_image.py
2 parents e3db29f + cfa0347 commit fa69302

File tree

10 files changed

+577
-18
lines changed

10 files changed

+577
-18
lines changed

README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,22 @@ loop.run_until_complete(main())
4949

5050
```
5151

52+
#### Random Prompt
53+
54+
```python
55+
from novelai_python.utils.random_prompt import RandomPromptGenerator
56+
57+
s = RandomPromptGenerator(nsfw_enabled=False).generate()
58+
print(s)
59+
```
60+
61+
#### Run A Server
62+
63+
```shell
64+
pip install novelai_python
65+
python3 -m novelai_python.server -h '0.0.0.0' -p 7888
66+
```
67+
5268
## Acknowledgements 🙏
5369

5470
[BackEnd](https://api.novelai.net/docs)

pdm.lock

Lines changed: 341 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "novelai-python"
3-
version = "0.1.5"
3+
version = "0.1.6"
44
description = "Novelai Python Binding With Pydantic"
55
authors = [
66
{ name = "sudoskys", email = "[email protected]" },
@@ -14,6 +14,8 @@ dependencies = [
1414
"shortuuid>=1.0.11",
1515
"Pillow>=10.2.0",
1616
"curl-cffi>=0.5.10",
17+
"fastapi>=0.109.0",
18+
"uvicorn[standard]>=0.27.0.post1",
1719
]
1820
requires-python = ">=3.8"
1921
readme = "README.md"

src/novelai_python/_exceptions.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,11 @@ class APIError(NovelAiError):
2222
code: Optional[str] = None
2323
response: Union[Dict[str, Any], str] = None
2424

25-
def __init__(self, message: str, request: Any, response: Any, status_code: str) -> None:
26-
super().__init__(message)
25+
def __init__(self, message: str, request: dict, response: Union[dict, str], status_code: str) -> None:
26+
if not isinstance(response, dict):
27+
response = {"error": f"data type error, should be dict, but got {type(response)}"}
28+
if not isinstance(request, dict):
29+
request = {"error": f"data type error, should be dict, but got {type(request)}"}
2730
self.request = request
2831
self.message = message
2932
self.code = status_code
@@ -35,5 +38,5 @@ class AuthError(APIError):
3538
AuthError is raised when the API returns an error.
3639
"""
3740

38-
def __init__(self, message: str, request: Any, response: Any, status_code: str) -> None:
41+
def __init__(self, message: str, request: dict, response: Union[dict, str], status_code: str) -> None:
3942
super().__init__(message, request, response, status_code)

src/novelai_python/credential/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# @File : __init__.py.py
55
# @Software: PyCharm
66
from .JwtToken import JwtCredential
7-
7+
from pydantic import SecretStr
88
__all__ = [
9-
"JwtCredential"
9+
"JwtCredential",
10+
"SecretStr"
1011
]

src/novelai_python/sdk/ai/generate_image.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
import json
77
import random
88
from io import BytesIO
9-
from typing import Optional, Union
9+
from typing import Optional, Union, Literal
1010
from zipfile import ZipFile
1111

1212
import httpx
1313
from curl_cffi.requests import AsyncSession
1414
from loguru import logger
15-
from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator
15+
from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator, model_validator
1616

1717
from ..._exceptions import APIError, AuthError
1818
from ..._response import ImageGenerateResp
@@ -30,6 +30,7 @@ class Params(BaseModel):
3030
controlnet_strength: Optional[int] = 1
3131
dynamic_thresholding: Optional[bool] = False
3232
height: Optional[int] = 1216
33+
image: Optional[str] = None # img2img,base64
3334
legacy: Optional[bool] = False
3435
legacy_v3_extend: Optional[bool] = False
3536
n_samples: Optional[int] = 1
@@ -41,7 +42,7 @@ class Params(BaseModel):
4142
"extra digit, fewer digits, cropped, worst quality, low quality, normal quality, "
4243
"jpeg artifacts, signature, watermark, username, blurry"
4344
)
44-
noise_schedule: Optional[str] = "native"
45+
noise_schedule: Optional[Union[str, Literal['native', 'polyexponential', 'exponential']]] = "native"
4546
params_version: Optional[int] = 1
4647
qualityToggle: Optional[bool] = True
4748
sampler: Optional[str] = "k_euler"
@@ -60,15 +61,25 @@ def seed_validator(cls, v: int):
6061
v = random.randint(0, 2 ** 32 - 1)
6162
return v
6263

64+
@model_validator(mode="after")
65+
def validate_img2img(self):
66+
67+
image = True if self.image else False
68+
add_origin = True if self.add_original_image else False
69+
if image != add_origin:
70+
raise ValueError('Invalid Model Params For img2img2 mode... image should match add_original_image!')
71+
return self
72+
6373
@field_validator('sampler')
6474
def sampler_validator(cls, v: str):
65-
if v not in ["k_euler", "k_euler_ancestral", "k_dpmpp_2m", "k_dpmpp_sde", "ddim_v3"]:
75+
if v not in ["k_euler", "k_euler_ancestral", 'k_dpmpp_2s_ancestral', "k_dpmpp_2m", "k_dpmpp_sde",
76+
"ddim_v3"]:
6677
raise ValueError("Invalid sampler.")
6778
return v
6879

6980
@field_validator('noise_schedule')
7081
def noise_schedule_validator(cls, v: str):
71-
if v not in ["native", "exponential", "polyexponential"]:
82+
if v not in ["native", "karras", "exponential", "polyexponential"]:
7283
raise ValueError("Invalid noise_schedule.")
7384
return v
7485

@@ -111,7 +122,7 @@ def n_samples_validator(cls, v: int):
111122
raise ValueError("Invalid n_samples, must be less than 8.")
112123
return v
113124

114-
action: Optional[str] = "generate"
125+
action: Union[str, Literal["generate", "img2img"]] = "generate"
115126
input: str = "1girl, best quality, amazing quality, very aesthetic, absurdres"
116127
model: Optional[str] = "nai-diffusion-3"
117128
parameters: Params = Params()
@@ -162,6 +173,8 @@ def validate_charge(self):
162173
def build(cls,
163174
prompt: str,
164175
*,
176+
model: str = "nai-diffusion-3",
177+
action: Literal['generate', 'img2img'] = 'generate',
165178
negative_prompt: Optional[str] = None,
166179
override_negative_prompt: bool = False,
167180
seed: int = -1,
@@ -175,6 +188,8 @@ def build(cls,
175188
正负面, step, cfg, 采样方式, seed
176189
:param override_negative_prompt:
177190
:param prompt:
191+
:param model:
192+
:param action: Mode for img generate
178193
:param negative_prompt:
179194
:param seed:
180195
:param steps:
@@ -209,13 +224,15 @@ def build(cls,
209224
param = {k: v for k, v in param.items() if v is not None}
210225
return cls(
211226
input=prompt,
227+
model=model,
228+
action=action,
212229
parameters=cls.Params(**param)
213230
)
214231

215232
async def generate(self, session: Union[AsyncSession, JwtCredential]) -> ImageGenerateResp:
216233
if isinstance(session, JwtCredential):
217234
session = session.session
218-
request_data = self.model_dump()
235+
request_data = self.model_dump(exclude_none=True)
219236
logger.debug(f"Request Data: {request_data}")
220237
try:
221238
assert hasattr(session, "post"), "session must have post method."

src/novelai_python/server.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# -*- coding: utf-8 -*-
2+
# @Time : 2024/1/30 下午11:05
3+
# @Author : sudoskys
4+
# @File : server.py
5+
# @Software: PyCharm
6+
import io
7+
import sys
8+
import zipfile
9+
10+
import uvicorn
11+
from fastapi import FastAPI, Depends, Security
12+
from fastapi.security import APIKeyHeader
13+
from starlette.responses import JSONResponse, StreamingResponse
14+
15+
from .credential import JwtCredential, SecretStr
16+
from .sdk.ai.generate_image import GenerateImageInfer
17+
18+
app = FastAPI()
19+
token_key = APIKeyHeader(name="Authorization")
20+
session = {}
21+
22+
23+
def get_session(token: str):
24+
if token not in session:
25+
session[token] = JwtCredential(jwt_token=SecretStr(token))
26+
return session[token]
27+
28+
29+
def get_current_token(auth_key: str = Security(token_key)):
30+
return auth_key
31+
32+
33+
@app.get("/health")
34+
async def health():
35+
return {"status": "ok"}
36+
37+
38+
@app.post("/ai/generate_image")
39+
async def generate_image(
40+
req: GenerateImageInfer,
41+
current_token: str = Depends(get_current_token)
42+
):
43+
"""
44+
生成图片
45+
:param current_token: Authorization
46+
:param req: GenerateImageInfer
47+
:return:
48+
"""
49+
try:
50+
_result = await req.generate(session=get_session(current_token))
51+
zip_file_bytes = io.BytesIO()
52+
with zipfile.ZipFile(zip_file_bytes, mode="w", compression=zipfile.ZIP_DEFLATED) as zip_file:
53+
for file in _result.files:
54+
zip_file.writestr(zinfo_or_arcname=file[0], data=file[1])
55+
# return the zip file
56+
zip_file_bytes.seek(0)
57+
return StreamingResponse(zip_file_bytes, media_type='application/zip', headers={
58+
'Content-Disposition': 'attachment;filename=multiple_files.zip'
59+
})
60+
except Exception as e:
61+
return JSONResponse(status_code=500, content=e.__dict__)
62+
63+
64+
# 获取输入参数
65+
def usage():
66+
print("Usage: python -m novelai_python.server -h <host> -p <port>")
67+
sys.exit(0)
68+
69+
70+
if __name__ == '__main__':
71+
import getopt
72+
73+
opts = {}
74+
try:
75+
opts, args = getopt.getopt(sys.argv[1:], "h:p:", ["host=", "port="])
76+
except getopt.GetoptError:
77+
usage()
78+
opts = dict(opts)
79+
server_host = opts.get("-h", "0.0.0.0")
80+
server_port = int(opts.get("-p", 10087))
81+
print(f"Docs: http://{server_host}:{server_port}/docs")
82+
uvicorn.run(app, host=server_host, port=server_port)

src/novelai_python/utils/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import json
88
from typing import Union
99

10+
from loguru import logger
11+
1012
from .hash import NovelAiMetadata
1113

1214

@@ -28,5 +30,6 @@ def try_jsonfy(obj: Union[str, dict, list, tuple]):
2830
"""
2931
try:
3032
return json.loads(obj)
31-
except json.JSONDecoder:
33+
except Exception:
34+
logger.error(f"Decode Error {obj}")
3235
return f"Decode Error {type(obj)}"

tests/test_random_prompt.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# -*- coding: utf-8 -*-
2+
# @Time : 2024/1/27 上午10:41
3+
# @Author : sudoskys
4+
# @File : test_random_prompt.py
5+
# @Software: PyCharm
6+
7+
from src.novelai_python.utils.random_prompt import RandomPromptGenerator
8+
9+
10+
def test_generate_returns_string():
11+
generator = RandomPromptGenerator(nsfw_enabled=True)
12+
result = generator.generate()
13+
assert isinstance(result, str)
14+
15+
16+
def test_generate_returns_non_empty_string():
17+
generator = RandomPromptGenerator(nsfw_enabled=True)
18+
result = generator.generate()
19+
assert len(result) > 0
20+
21+
22+
def test_generate_returns_different_results():
23+
generator = RandomPromptGenerator(nsfw_enabled=True)
24+
result1 = generator.generate()
25+
result2 = generator.generate()
26+
assert result1 != result2
27+
28+
29+
def test_generate_with_nsfw_disabled():
30+
generator = RandomPromptGenerator(nsfw_enabled=False)
31+
result = generator.generate()
32+
assert 'nsfw' not in result
33+
34+
35+
def test_generate_with_nsfw_enabled():
36+
generator = RandomPromptGenerator(nsfw_enabled=True)
37+
result = generator.generate()
38+
assert 'nsfw' in result
39+
40+
41+
def test_get_weighted_choice_returns_string():
42+
generator = RandomPromptGenerator(nsfw_enabled=True)
43+
result = generator.get_weighted_choice([['tag1', 1], ['tag2', 2]], [])
44+
assert isinstance(result, str)
45+
46+
47+
def test_get_weighted_choice_returns_valid_tag():
48+
generator = RandomPromptGenerator(nsfw_enabled=True)
49+
result = generator.get_weighted_choice([['tag1', 1], ['tag2', 2]], [])
50+
assert result in ['tag1', 'tag2']
51+
52+
53+
def test_character_features_returns_list():
54+
generator = RandomPromptGenerator(nsfw_enabled=True)
55+
result = generator.character_features('m', 'front', True, 1)
56+
assert isinstance(result, list)
57+
58+
59+
def test_character_features_returns_non_empty_list():
60+
generator = RandomPromptGenerator(nsfw_enabled=True)
61+
result = generator.character_features('m', 'front', True, 1)
62+
assert len(result) > 0
63+
64+
65+
def test_character_features_with_different_genders():
66+
generator = RandomPromptGenerator(nsfw_enabled=True)
67+
result_m = generator.character_features('m', 'front', True, 1)
68+
result_f = generator.character_features('f', 'front', True, 1)
69+
result_o = generator.character_features('o', 'front', True, 1)
70+
assert result_m != result_f != result_o

tests/test_server_run.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# -*- coding: utf-8 -*-
2+
# @Time : 2024/1/30 下午11:52
3+
# @Author : sudoskys
4+
# @File : test_server_run.py
5+
# @Software: PyCharm
6+
from fastapi.testclient import TestClient
7+
8+
from src.novelai_python.sdk.ai.generate_image import GenerateImageInfer
9+
from src.novelai_python.server import app, get_session
10+
11+
client = TestClient(app)
12+
13+
14+
def test_health_check():
15+
response = client.get("/health")
16+
assert response.status_code == 200
17+
assert response.json() == {"status": "ok"}
18+
19+
20+
def test_generate_image_with_valid_token():
21+
valid_token = "valid_token"
22+
get_session(valid_token) # to simulate a valid session
23+
response = client.post(
24+
"/ai/generate_image",
25+
headers={"Authorization": valid_token},
26+
json=GenerateImageInfer(input="1girl").model_dump()
27+
)
28+
assert response.status_code == 500

0 commit comments

Comments
 (0)