Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 100 additions & 2 deletions aispeech/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,103 @@
# 智能对话

[官方文档](https://developers.weixin.qq.com/doc/aispeech/platform/INTERFACEDOCUMENT.html)
[官方文档](https://developers.weixin.qq.com/doc/aispeech/confapi/dialog/token.html)

## 快速入门
## 快速入门

```go
import (
"github.com/silenceper/wechat/v2"
"github.com/silenceper/wechat/v2/aispeech/config"
"github.com/silenceper/wechat/v2/aispeech/dialog"
"github.com/silenceper/wechat/v2/cache"
)

wc := wechat.NewWechat()
memory := cache.NewMemory()

ai := wc.GetAISpeech(&config.Config{
AppID: "xxx",
Token: "xxx",
AESKey: "xxx",
Account: "admin",
Cache: memory,
})

dialogClient := ai.GetDialog()

accessToken, err := dialogClient.GetAccessToken()
if err != nil {
return err
}

res, err := dialogClient.Query(&dialog.QueryRequest{
Query: "你好",
Env: "online",
UserID: "user-1",
})
```

## 简单问答导入与发布

```go
task, err := dialogClient.ImportJSON(&dialog.ImportJSONRequest{
Mode: 0,
Data: []dialog.BotIntent{{
Skill: "售前咨询",
Intent: "查询营业时间",
Disable: false,
Questions: []string{"你们几点开门", "营业时间是什么时候"},
Answers: []string{"我们的营业时间是周一至周五 9:00-18:00"},
}},
})
if err != nil {
return err
}

asyncResult, err := dialogClient.FetchAsync(&dialog.FetchAsyncRequest{
TaskID: task.TaskID,
})
if err != nil {
return err
}

publish, err := dialogClient.Publish()
if err != nil {
return err
}

progress, err := dialogClient.GetEffectiveProgress(&dialog.EffectiveProgressRequest{
Env: "online",
})
```

当前封装微信智能对话开放接口中“接入智能对话”的 6 个接口:

- `POST /v2/token`
- `POST /v2/bot/import/json`
- `POST /v2/async/fetch`
- `POST /v2/bot/publish`
- `POST /v2/bot/effective_progress`
- `POST /v2/bot/query`

## 真实接口验证

默认单元测试不会访问真实微信接口。如需验证 token 和 query 只读链路,可设置以下环境变量后运行:

```bash
AISPEECH_INTEGRATION=1
AISPEECH_APPID=xxx
AISPEECH_TOKEN=xxx
AISPEECH_AES_KEY=xxx
go test ./aispeech/dialog -run TestIntegrationAccessTokenAndQuery
```

如需验证导入、异步任务查询、发布、发布进度查询和命中回答的完整链路,可显式开启变更型测试。该测试会导入一个 `CodexSmokeTest` 临时问答并发布到线上机器人:

```bash
AISPEECH_INTEGRATION_MUTATION=1
AISPEECH_APPID=xxx
AISPEECH_TOKEN=xxx
AISPEECH_AES_KEY=xxx
go test ./aispeech/dialog -run TestIntegrationFullDialogFlow -v
```
60 changes: 60 additions & 0 deletions aispeech/aispeech.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package aispeech

import (
stdcontext "context"

"github.com/silenceper/wechat/v2/aispeech/config"
"github.com/silenceper/wechat/v2/aispeech/context"
"github.com/silenceper/wechat/v2/aispeech/dialog"
"github.com/silenceper/wechat/v2/credential"
)

// AISpeech 微信智能对话相关 API.
type AISpeech struct {
ctx *context.Context
dialog *dialog.Dialog
}

// NewAISpeech 实例化智能对话 API.
func NewAISpeech(cfg *config.Config) *AISpeech {
ctx := &context.Context{
Config: cfg,
AccessTokenContextHandle: dialog.NewAccessToken(cfg),
}
return &AISpeech{ctx: ctx}
}

// GetContext get Context.
func (a *AISpeech) GetContext() *context.Context {
return a.ctx
}

// SetAccessTokenHandle 自定义 access_token 获取方式.
func (a *AISpeech) SetAccessTokenHandle(accessTokenHandle credential.AccessTokenHandle) {
a.ctx.AccessTokenContextHandle = credential.AccessTokenCompatibleHandle{
AccessTokenHandle: accessTokenHandle,
}
}

// SetAccessTokenContextHandle 自定义 access_token 获取方式.
func (a *AISpeech) SetAccessTokenContextHandle(accessTokenContextHandle credential.AccessTokenContextHandle) {
a.ctx.AccessTokenContextHandle = accessTokenContextHandle
}

// GetAccessToken 获取 access token.
func (a *AISpeech) GetAccessToken() (string, error) {
return a.ctx.GetAccessToken()
}

// GetAccessTokenContext 获取 access token.
func (a *AISpeech) GetAccessTokenContext(ctx stdcontext.Context) (string, error) {
return a.ctx.GetAccessTokenContext(ctx)
}

// GetDialog 获取对话平台 API.
func (a *AISpeech) GetDialog() *dialog.Dialog {
if a.dialog == nil {
a.dialog = dialog.NewDialog(a.ctx)
}
return a.dialog
}
61 changes: 61 additions & 0 deletions aispeech/aispeech_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package aispeech

import (
stdcontext "context"
"testing"

"github.com/silenceper/wechat/v2/aispeech/config"
"github.com/silenceper/wechat/v2/cache"
)

type staticAccessToken struct {
token string
}

func (s staticAccessToken) GetAccessToken() (string, error) {
return s.token, nil
}

type staticAccessTokenContext struct {
token string
}

type contextTokenKey struct{}

func (s staticAccessTokenContext) GetAccessToken() (string, error) {
return s.GetAccessTokenContext(stdcontext.Background())
}

func (s staticAccessTokenContext) GetAccessTokenContext(ctx stdcontext.Context) (string, error) {
if v := ctx.Value(contextTokenKey{}); v != nil {
return v.(string), nil
}
return s.token, nil
}

func TestSetAccessTokenHandle(t *testing.T) {
ai := NewAISpeech(&config.Config{Cache: cache.NewMemory()})
ai.SetAccessTokenHandle(staticAccessToken{token: "custom-token"})

token, err := ai.GetAccessToken()
if err != nil {
t.Fatalf("GetAccessToken error: %v", err)
}
if token != "custom-token" {
t.Fatalf("bad token: %s", token)
}
}

func TestSetAccessTokenContextHandle(t *testing.T) {
ai := NewAISpeech(&config.Config{Cache: cache.NewMemory()})
ai.SetAccessTokenContextHandle(staticAccessTokenContext{token: "custom-token"})

ctx := stdcontext.WithValue(stdcontext.Background(), contextTokenKey{}, "context-token")
token, err := ai.GetAccessTokenContext(ctx)
if err != nil {
t.Fatalf("GetAccessTokenContext error: %v", err)
}
if token != "context-token" {
t.Fatalf("bad token: %s", token)
}
}
23 changes: 23 additions & 0 deletions aispeech/config/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package config

import "github.com/silenceper/wechat/v2/cache"

const defaultBaseURL = "https://openaiapi.weixin.qq.com"

// Config for 微信智能对话.
type Config struct {
AppID string `json:"app_id"`
Token string `json:"token"`
AESKey string `json:"aes_key"`
Account string `json:"account"`
BaseURL string `json:"base_url"`
Cache cache.Cache
}

// GetBaseURL returns the API base URL.
func (cfg *Config) GetBaseURL() string {
if cfg.BaseURL == "" {
return defaultBaseURL
}
return cfg.BaseURL
}
15 changes: 15 additions & 0 deletions aispeech/config/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package config

import "testing"

func TestGetBaseURL(t *testing.T) {
cfg := &Config{}
if cfg.GetBaseURL() != defaultBaseURL {
t.Fatalf("bad default base url: %s", cfg.GetBaseURL())
}

cfg.BaseURL = "http://example.com"
if cfg.GetBaseURL() != "http://example.com" {
t.Fatalf("bad custom base url: %s", cfg.GetBaseURL())
}
}
12 changes: 12 additions & 0 deletions aispeech/context/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package context

import (
"github.com/silenceper/wechat/v2/aispeech/config"
"github.com/silenceper/wechat/v2/credential"
)

// Context struct
type Context struct {
*config.Config
credential.AccessTokenContextHandle
}
116 changes: 116 additions & 0 deletions aispeech/dialog/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package dialog

import (
stdcontext "context"
"encoding/json"

"github.com/silenceper/wechat/v2/aispeech/encryptor"
)

// GetAccessToken 获取 access token.
func (d *Dialog) GetAccessToken() (string, error) {
return d.Context.GetAccessToken()
}

// GetAccessTokenContext 获取 access token.
func (d *Dialog) GetAccessTokenContext(ctx stdcontext.Context) (string, error) {
return d.Context.GetAccessTokenContext(ctx)
}

// ImportJSON 简单问答导入.
func (d *Dialog) ImportJSON(req *ImportJSONRequest) (*ImportJSONResponse, error) {
return d.ImportJSONContext(stdcontext.Background(), req)
}

// ImportJSONContext 简单问答导入.
func (d *Dialog) ImportJSONContext(ctx stdcontext.Context, req *ImportJSONRequest) (*ImportJSONResponse, error) {
var res ImportJSONResponse
requestID, err := d.postJSON(ctx, importJSONPath, req, &res, "AISpeechImportJSON")
res.RequestID = requestID
return &res, err
}

// Publish 发布机器人.
func (d *Dialog) Publish() (*PublishResponse, error) {
return d.PublishContext(stdcontext.Background())
}

// PublishContext 发布机器人.
func (d *Dialog) PublishContext(ctx stdcontext.Context) (*PublishResponse, error) {
var res PublishResponse
requestID, err := d.postEmpty(ctx, publishPath, &res, "AISpeechPublish")
res.RequestID = requestID
return &res, err
}

// GetEffectiveProgress 查询机器人发布进度.
func (d *Dialog) GetEffectiveProgress(req *EffectiveProgressRequest) (*EffectiveProgressResponse, error) {
return d.GetEffectiveProgressContext(stdcontext.Background(), req)
}

// GetEffectiveProgressContext 查询机器人发布进度.
func (d *Dialog) GetEffectiveProgressContext(ctx stdcontext.Context, req *EffectiveProgressRequest) (*EffectiveProgressResponse, error) {
var res EffectiveProgressResponse
requestID, err := d.postJSON(ctx, effectiveProgressPath, req, &res, "AISpeechGetEffectiveProgress")
res.RequestID = requestID
return &res, err
}

// FetchAsync 查询异步任务.
func (d *Dialog) FetchAsync(req *FetchAsyncRequest) (*FetchAsyncResponse, error) {
return d.FetchAsyncContext(stdcontext.Background(), req)
}

// FetchAsyncContext 查询异步任务.
func (d *Dialog) FetchAsyncContext(ctx stdcontext.Context, req *FetchAsyncRequest) (*FetchAsyncResponse, error) {
var res FetchAsyncResponse
requestID, err := d.postJSON(ctx, fetchAsyncPath, req, &res, "AISpeechFetchAsync")
res.RequestID = requestID
return &res, err
}

// Query 调用智能对话.
func (d *Dialog) Query(req *QueryRequest) (*QueryResponse, error) {
return d.QueryContext(stdcontext.Background(), req)
}

// QueryContext 调用智能对话.
func (d *Dialog) QueryContext(ctx stdcontext.Context, req *QueryRequest) (*QueryResponse, error) {
body, err := json.Marshal(req)
if err != nil {
return nil, err
}
encryptedBody, err := encryptor.Encrypt(d.AESKey, body)
if err != nil {
return nil, err
}
accessToken, err := d.GetAccessTokenContext(ctx)
if err != nil {
return nil, err
}
if accessToken == "" {
return nil, errEmptyAccessToken
}
response, err := post(ctx, d.Config, queryPath, []byte(encryptedBody), "text/plain", accessToken, "")
if err != nil {
return nil, err
}
plainResponse := response
if !json.Valid(response) {
plainResponse, err = encryptor.Decrypt(d.AESKey, string(response))
if err != nil {
return nil, err
}
}

var res QueryResponse
requestID, err := decodeResponse(plainResponse, &res, "AISpeechQuery")
if err != nil {
return nil, err
}
res.RequestID = requestID
if json.Valid([]byte(res.Answer)) {
res.RawAnswer = json.RawMessage(res.Answer)
}
return &res, nil
}
Loading
Loading