Skip to content

Commit 53d2c67

Browse files
authored
Merge pull request #11 from LlmKira/dev
feat(Enum):add enum for /ai/generate_image
2 parents 47bc71a + 377c1d9 commit 53d2c67

File tree

5 files changed

+88
-48
lines changed

5 files changed

+88
-48
lines changed

playground/generate_image.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from novelai_python import APIError, Login
1313
from novelai_python import GenerateImageInfer, ImageGenerateResp, JwtCredential
14+
from novelai_python.sdk.ai.generate_image import Action
1415

1516
load_dotenv()
1617

@@ -25,7 +26,9 @@ async def main():
2526
).request()
2627
try:
2728
gen = GenerateImageInfer.build(
28-
prompt=f"1girl, winter, jacket, sfw, angel, flower,{enhance}")
29+
prompt=f"1girl, winter, jacket, sfw, angel, flower,{enhance}",
30+
action=Action.GENERATE,
31+
)
2932
cost = gen.calculate_cost(is_opus=True)
3033
print(f"charge: {cost} if you are vip3")
3134
print(f"charge: {gen.calculate_cost(is_opus=True)}")

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "novelai-python"
3-
version = "0.2.2"
3+
version = "0.2.3"
44
description = "Novelai Python Binding With Pydantic"
55
authors = [
66
{ name = "sudoskys", email = "[email protected]" },

src/novelai_python/sdk/ai/generate_image.py

Lines changed: 64 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
import json
77
import math
88
import random
9+
from enum import Enum, IntEnum
910
from io import BytesIO
10-
from typing import Optional, Union, Literal
11+
from typing import Optional, Union
1112
from zipfile import ZipFile
1213

1314
import httpx
@@ -22,6 +23,54 @@
2223
from ...utils import try_jsonfy, NovelAiMetadata
2324

2425

26+
class Sampler(Enum):
27+
K_EULER = "k_euler"
28+
K_EULER_ANCESTRAL = "k_euler_ancestral"
29+
K_DPMPP_2S_ANCESTRAL = "k_dpmpp_2s_ancestral"
30+
K_DPMPP_2M = "k_dpmpp_2m"
31+
K_DPMPP_SDE = "k_dpmpp_sde"
32+
DDIM_V3 = "ddim_v3"
33+
34+
35+
class NoiseSchedule(Enum):
36+
NATIVE = "native"
37+
KARRAS = "karras"
38+
EXPONENTIAL = "exponential"
39+
POLYEXPONENTIAL = "polyexponential"
40+
41+
42+
class UCPreset(IntEnum):
43+
TYPE0 = 0
44+
TYPE1 = 1
45+
TYPE2 = 2
46+
TYPE3 = 3
47+
48+
49+
class Action(Enum):
50+
GENERATE = "generate"
51+
IMG2IMG = "img2img"
52+
INFILL = "infill"
53+
54+
55+
class Model(Enum):
56+
NAI_DIFFUSION_3 = "nai-diffusion-3"
57+
NAI_DIFFUSION_3_INPAINTING = "nai-diffusion-3-inpainting"
58+
59+
60+
class Resolution(Enum):
61+
RES_512_768 = (512, 768)
62+
RES_768_512 = (768, 512)
63+
RES_640_640 = (640, 640)
64+
RES_832_1216 = (832, 1216)
65+
RES_1216_832 = (1216, 832)
66+
RES_1024_1024 = (1024, 1024)
67+
RES_1024_1536 = (1024, 1536)
68+
RES_1536_1024 = (1536, 1024)
69+
RES_1472_1472 = (1472, 1472)
70+
RES_1088_1920 = (1088, 1920)
71+
RES_1920_1088 = (1920, 1088)
72+
73+
2574
class GenerateImageInfer(BaseModel):
2675
_endpoint: Optional[str] = PrivateAttr("https://api.novelai.net")
2776
_charge: bool = PrivateAttr(False)
@@ -44,15 +93,15 @@ class Params(BaseModel):
4493

4594
n_samples: Optional[int] = Field(1, ge=1, le=8)
4695
negative_prompt: Optional[str] = ''
47-
noise_schedule: Optional[Union[str, Literal['native', 'polyexponential', 'exponential']]] = "native"
96+
noise_schedule: Optional[NoiseSchedule] = NoiseSchedule.NATIVE
4897

4998
# Misc
5099
params_version: Optional[int] = 1
51100
legacy: Optional[bool] = False
52101
legacy_v3_extend: Optional[bool] = False
53102

54103
qualityToggle: Optional[bool] = True
55-
sampler: Optional[str] = "k_euler"
104+
sampler: Optional[Sampler] = Sampler.K_EULER
56105
scale: Optional[float] = Field(6.0, ge=0, le=10, multiple_of=0.1)
57106
# Seed
58107
seed: Optional[int] = Field(
@@ -69,7 +118,7 @@ class Params(BaseModel):
69118
sm: Optional[bool] = False
70119
sm_dyn: Optional[bool] = False
71120
steps: Optional[int] = Field(28, ge=1, le=50)
72-
ucPreset: Optional[Literal[0, 1, 2, 3]] = 0
121+
ucPreset: Optional[UCPreset] = 0
73122
uncond_scale: Optional[float] = Field(1.0, ge=0, le=1.5, multiple_of=0.05)
74123
width: Optional[int] = Field(832, ge=64, le=49152)
75124

@@ -81,25 +130,6 @@ def validate_img2img(self):
81130
raise ValueError('Invalid Model Params For img2img2 mode... image should match add_original_image!')
82131
return self
83132

84-
@field_validator('sampler')
85-
def sampler_validator(cls, v: str):
86-
if v not in ["k_euler", "k_euler_ancestral", 'k_dpmpp_2s_ancestral', "k_dpmpp_2m", "k_dpmpp_sde",
87-
"ddim_v3"]:
88-
raise ValueError("Invalid sampler.")
89-
return v
90-
91-
@field_validator('noise_schedule')
92-
def noise_schedule_validator(cls, v: str):
93-
if v not in ["native", "karras", "exponential", "polyexponential"]:
94-
raise ValueError("Invalid noise_schedule.")
95-
return v
96-
97-
@field_validator('steps')
98-
def steps_validator(cls, v: int):
99-
if v > 28:
100-
logger.warning(f"steps {v} > 28, maybe charge more.")
101-
return v
102-
103133
@field_validator('width')
104134
def width_validator(cls, v: int):
105135
"""
@@ -122,9 +152,9 @@ def height_validator(cls, v: int):
122152
raise ValueError("Invalid height, must be multiple of 64.")
123153
return v
124154

125-
action: Union[str, Literal["generate", "img2img", "infill"]] = "generate"
155+
action: Union[str, Action] = Field(Action.GENERATE, description="Mode for img generate")
126156
input: str = "1girl, best quality, amazing quality, very aesthetic, absurdres"
127-
model: Optional[str] = "nai-diffusion-3"
157+
model: Optional[Model] = "nai-diffusion-3"
128158
parameters: Params = Params()
129159
model_config = ConfigDict(extra="ignore")
130160

@@ -170,18 +200,6 @@ def endpoint(self):
170200
def endpoint(self, value):
171201
self._endpoint = value
172202

173-
@staticmethod
174-
def valid_wh():
175-
"""
176-
宽高
177-
:return:
178-
"""
179-
return [
180-
(832, 1216),
181-
(1216, 832),
182-
(1024, 1024),
183-
]
184-
185203
def calculate_cost(self, is_opus: bool = False):
186204
"""
187205
Calculate the Anlas cost of current parameters.
@@ -227,17 +245,18 @@ def calculate_cost(self, is_opus: bool = False):
227245
def build(cls,
228246
prompt: str,
229247
*,
230-
model: str = "nai-diffusion-3",
231-
action: Literal['generate', 'img2img'] = 'generate',
248+
model: Union[Model, str] = "nai-diffusion-3",
249+
action: Union[Action, str] = 'generate',
232250
negative_prompt: str = "",
233251
seed: int = None,
234252
steps: int = 28,
235253
cfg_rescale: int = 0,
236-
sampler: str = "k_euler",
254+
sampler: Union[Sampler, str] = Sampler.K_EULER,
237255
width: int = 832,
238256
height: int = 1216,
239257
qualityToggle: bool = True,
240-
ucPreset: int = 0,
258+
ucPreset: Union[UCPreset, int] = UCPreset.TYPE0,
259+
**kwargs
241260
):
242261
"""
243262
正负面, step, cfg, 采样方式, seed
@@ -257,7 +276,7 @@ def build(cls,
257276
"""
258277
assert isinstance(prompt, str)
259278
_negative_prompt = negative_prompt
260-
param = {
279+
kwargs.update({
261280
"negative_prompt": _negative_prompt,
262281
"seed": seed,
263282
"steps": steps,
@@ -267,9 +286,9 @@ def build(cls,
267286
"height": height,
268287
"qualityToggle": qualityToggle,
269288
"ucPreset": ucPreset,
270-
}
289+
})
271290
# 清理空值
272-
param = {k: v for k, v in param.items() if v is not None}
291+
param = {k: v for k, v in kwargs.items() if v is not None}
273292
return cls(
274293
input=prompt,
275294
model=model,
@@ -288,7 +307,7 @@ async def generate(self, session: Union[AsyncSession, "CredentialBase"],
288307
"""
289308
if isinstance(session, CredentialBase):
290309
session = await session.get_session()
291-
request_data = self.model_dump(exclude_none=True)
310+
request_data = self.model_dump(mode="json", exclude_none=True)
292311
logger.debug(f"Request Data: {request_data}")
293312
try:
294313
assert hasattr(session, "post"), "session must have post method."

src/novelai_python/sdk/user/login.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ async def request(self,
6565
Request to get user access token
6666
:return:
6767
"""
68-
request_data = self.model_dump(exclude_none=True)
68+
request_data = self.model_dump(mode="json", exclude_none=True)
6969
logger.debug("Login")
7070
try:
7171
assert hasattr(self.session, "post"), "session must have get method."

src/novelai_python/server.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from .credential import JwtCredential, SecretStr
1717
from .sdk.ai.generate_image import GenerateImageInfer
18+
from .sdk.user.information import Information
1819
from .sdk.user.login import Login
1920
from .sdk.user.subscription import Subscription
2021

@@ -55,6 +56,23 @@ async def login(
5556
return JSONResponse(status_code=500, content=e.__dict__)
5657

5758

59+
@app.get("/user/information")
60+
async def information(
61+
current_token: str = Depends(get_current_token)
62+
):
63+
"""
64+
用户信息
65+
:param current_token: Authorization
66+
:return:
67+
"""
68+
try:
69+
_result = await Information().request(session=get_session(current_token))
70+
return _result.model_dump()
71+
except Exception as e:
72+
logger.exception(e)
73+
return JSONResponse(status_code=500, content=e.__dict__)
74+
75+
5876
@app.get("/user/subscription")
5977
async def subscription(
6078
current_token: str = Depends(get_current_token)

0 commit comments

Comments
 (0)