Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/api/handler/coze/workflow_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"context"
"errors"
"fmt"
"github.com/alicebob/miniredis/v2"
"net/http"
"os"
"reflect"
Expand All @@ -31,6 +30,8 @@ import (
"testing"
"time"

"github.com/alicebob/miniredis/v2"

"github.com/bytedance/mockey"
"github.com/cloudwego/eino/callbacks"
model2 "github.com/cloudwego/eino/components/model"
Expand Down
44 changes: 26 additions & 18 deletions backend/application/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ import (
variablesImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/variables"
workflowImpl "github.com/coze-dev/coze-studio/backend/crossdomain/impl/workflow"
"github.com/coze-dev/coze-studio/backend/infra/contract/eventbus"
"github.com/coze-dev/coze-studio/backend/infra/impl/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/impl/checkpoint"
implEventbus "github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
)
Expand Down Expand Up @@ -191,7 +192,9 @@ func initPrimaryServices(ctx context.Context, basicServices *basicServices) (*pr

memorySVC := memory.InitService(basicServices.toMemoryServiceComponents())

knowledgeSVC, err := knowledge.InitService(basicServices.toKnowledgeServiceComponents(memorySVC))
knowledgeSVC, err := knowledge.InitService(ctx,
basicServices.toKnowledgeServiceComponents(memorySVC),
basicServices.eventbus.resourceEventBus)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -256,14 +259,18 @@ func (b *basicServices) toPluginServiceComponents() *plugin.ServiceComponents {
// toKnowledgeServiceComponents assembles the knowledge.ServiceComponents from
// the shared infra dependencies held by b, the RDB domain service of the given
// memory application services, the resource event bus and a default chat-model
// factory.
//
// NOTE(review): this span is taken from a diff view in which removed and added
// lines are shown together, so Storage and CacheCli each appear twice below —
// confirm which lines are current before relying on the field list.
func (b *basicServices) toKnowledgeServiceComponents(memoryService *memory.MemoryApplicationServices) *knowledge.ServiceComponents {
return &knowledge.ServiceComponents{
DB: b.infra.DB,
IDGenSVC: b.infra.IDGenSVC,
Storage: b.infra.TOSClient,
IDGen: b.infra.IDGenSVC,
RDB: memoryService.RDBDomainSVC,
Producer: b.infra.KnowledgeEventProducer,
SearchStoreManagers: b.infra.SearchStoreManagers,
EventBus: b.eventbus.resourceEventBus,
CacheCli: b.infra.CacheCli,
ParseManager: b.infra.ParserManager,
Storage: b.infra.TOSClient,
Rewriter: b.infra.Rewriter,
Reranker: b.infra.Reranker,
NL2Sql: b.infra.NL2SQL,
OCR: b.infra.OCR,
ParserManager: b.infra.ParserManager,
CacheCli: b.infra.CacheCli,
ModelFactory: chatmodel.NewDefaultFactory(),
}
}

Expand All @@ -280,18 +287,19 @@ func (b *basicServices) toMemoryServiceComponents() *memory.ServiceComponents {

// toWorkflowServiceComponents assembles the workflow.ServiceComponents from the
// shared infra dependencies held by b and the domain services exposed by the
// plugin, memory and knowledge application services. The checkpoint store is a
// Redis store created over the shared cache client.
//
// NOTE(review): this span is taken from a diff view in which removed and added
// lines are shown together, so the fields IDGen through CodeRunner are listed
// twice below; only WorkflowBuildInChatModel appears once.
func (b *basicServices) toWorkflowServiceComponents(pluginSVC *plugin.PluginApplicationService, memorySVC *memory.MemoryApplicationServices, knowledgeSVC *knowledge.KnowledgeApplicationService) *workflow.ServiceComponents {
return &workflow.ServiceComponents{
IDGen: b.infra.IDGenSVC,
DB: b.infra.DB,
Cache: b.infra.CacheCli,
Tos: b.infra.TOSClient,
ImageX: b.infra.ImageXClient,
DatabaseDomainSVC: memorySVC.DatabaseDomainSVC,
VariablesDomainSVC: memorySVC.VariablesDomainSVC,
PluginDomainSVC: pluginSVC.DomainSVC,
KnowledgeDomainSVC: knowledgeSVC.DomainSVC,
DomainNotifier: b.eventbus.resourceEventBus,
CPStore: checkpoint.NewRedisStore(b.infra.CacheCli),
CodeRunner: b.infra.CodeRunner,
IDGen: b.infra.IDGenSVC,
DB: b.infra.DB,
Cache: b.infra.CacheCli,
Tos: b.infra.TOSClient,
ImageX: b.infra.ImageXClient,
DatabaseDomainSVC: memorySVC.DatabaseDomainSVC,
VariablesDomainSVC: memorySVC.VariablesDomainSVC,
PluginDomainSVC: pluginSVC.DomainSVC,
KnowledgeDomainSVC: knowledgeSVC.DomainSVC,
DomainNotifier: b.eventbus.resourceEventBus,
CPStore: checkpoint.NewRedisStore(b.infra.CacheCli),
CodeRunner: b.infra.CodeRunner,
WorkflowBuildInChatModel: b.infra.WorkflowBuildInChatModel,
}
}

Expand Down
147 changes: 132 additions & 15 deletions backend/application/base/appinfra/app_infra.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ package appinfra

import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
Expand All @@ -29,26 +31,32 @@ import (

"github.com/cloudwego/eino-ext/components/embedding/ollama"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/volcengine/volc-sdk-golang/service/visual"

"github.com/coze-dev/coze-studio/backend/application/internal"
"github.com/coze-dev/coze-studio/backend/infra/contract/cache"
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
"github.com/coze-dev/coze-studio/backend/infra/contract/coderunner"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/nl2sql"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/ocr"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/parser"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/rerank"
"github.com/coze-dev/coze-studio/backend/infra/contract/document/searchstore"
"github.com/coze-dev/coze-studio/backend/infra/contract/embedding"
"github.com/coze-dev/coze-studio/backend/infra/contract/imagex"
"github.com/coze-dev/coze-studio/backend/infra/contract/messages2query"
"github.com/coze-dev/coze-studio/backend/infra/contract/modelmgr"
"github.com/coze-dev/coze-studio/backend/infra/impl/cache/redis"
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/direct"
"github.com/coze-dev/coze-studio/backend/infra/impl/coderunner/sandbox"
builtinNL2SQL "github.com/coze-dev/coze-studio/backend/infra/impl/document/nl2sql/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/ppocr"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/ocr/veocr"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/ppstructure"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
Expand All @@ -59,6 +67,7 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/impl/eventbus"
"github.com/coze-dev/coze-studio/backend/infra/impl/idgen"
"github.com/coze-dev/coze-studio/backend/infra/impl/imagex/veimagex"
builtinM2Q "github.com/coze-dev/coze-studio/backend/infra/impl/messages2query/builtin"
"github.com/coze-dev/coze-studio/backend/infra/impl/mysql"
"github.com/coze-dev/coze-studio/backend/infra/impl/storage"
"github.com/coze-dev/coze-studio/backend/pkg/lang/conv"
Expand All @@ -68,19 +77,24 @@ import (
)

// AppDependencies bundles the infrastructure clients and services that Init
// constructs and returns: database, cache, ID generator, search and storage
// clients, event-bus producers, model manager, code runner, and the
// document-processing components (OCR, parser, search stores, reranker,
// rewriter, NL2SQL).
//
// NOTE(review): this span is taken from a diff view in which removed and added
// lines are shown together, so the fields DB through SearchStoreManagers are
// listed twice below; KnowledgeEventProducer, Reranker, Rewriter, NL2SQL and
// WorkflowBuildInChatModel appear only in the second listing.
type AppDependencies struct {
DB *gorm.DB
CacheCli cache.Cmdable
IDGenSVC idgen.IDGenerator
ESClient es.Client
ImageXClient imagex.ImageX
TOSClient storage.Storage
ResourceEventProducer eventbus.Producer
AppEventProducer eventbus.Producer
ModelMgr modelmgr.Manager
CodeRunner coderunner.Runner
OCR ocr.OCR
ParserManager parser.Manager
SearchStoreManagers []searchstore.Manager
DB *gorm.DB
CacheCli cache.Cmdable
IDGenSVC idgen.IDGenerator
ESClient es.Client
ImageXClient imagex.ImageX
TOSClient storage.Storage
ResourceEventProducer eventbus.Producer
AppEventProducer eventbus.Producer
KnowledgeEventProducer eventbus.Producer
ModelMgr modelmgr.Manager
CodeRunner coderunner.Runner
OCR ocr.OCR
ParserManager parser.Manager
SearchStoreManagers []searchstore.Manager
Reranker rerank.Reranker
Rewriter messages2query.MessagesToQuery
NL2SQL nl2sql.NL2SQL
WorkflowBuildInChatModel chatmodel.BaseChatModel
}

func Init(ctx context.Context) (*AppDependencies, error) {
Expand Down Expand Up @@ -124,6 +138,23 @@ func Init(ctx context.Context) (*AppDependencies, error) {
return nil, fmt.Errorf("init app event producer failed, err=%w", err)
}

deps.KnowledgeEventProducer, err = initKnowledgeEventBusProducer()
if err != nil {
return nil, fmt.Errorf("init knowledge event bus producer failed, err=%w", err)
}

deps.Reranker = rrf.NewRRFReranker(0)

deps.Rewriter, err = initRewriter(ctx)
if err != nil {
return nil, fmt.Errorf("init rewriter failed, err=%w", err)
}

deps.NL2SQL, err = initNL2SQL(ctx)
if err != nil {
return nil, fmt.Errorf("init nl2sql failed, err=%w", err)
}

deps.ModelMgr, err = initModelMgr()
if err != nil {
return nil, fmt.Errorf("init model manager failed, err=%w", err)
Expand All @@ -133,11 +164,21 @@ func Init(ctx context.Context) (*AppDependencies, error) {

deps.OCR = initOCR()

imageAnnotationModel, _, err := internal.GetBuiltinChatModel(ctx, "IA_")
imageAnnotationModel, _, err := getBuiltinChatModel(ctx, "IA_")
if err != nil {
return nil, fmt.Errorf("get builtin chat model failed, err=%w", err)
}

var ok bool
deps.WorkflowBuildInChatModel, ok, err = getBuiltinChatModel(ctx, "WKR_")
if err != nil {
return nil, fmt.Errorf("get workflow builtin chat model failed, err=%w", err)
}

if !ok {
logs.CtxWarnf(ctx, "workflow builtin chat model for knowledge recall not configured")
}

deps.ParserManager, err = initParserManager(deps.TOSClient, deps.OCR, imageAnnotationModel)
if err != nil {
return nil, fmt.Errorf("init parser manager failed, err=%w", err)
Expand All @@ -164,6 +205,71 @@ func initSearchStoreManagers(ctx context.Context, es es.Client) ([]searchstore.M
return []searchstore.Manager{esSearchstoreManager, mgr}, nil
}

// initRewriter builds the messages-to-query rewriter.
//
// It resolves the builtin chat model for the "M2Q_" environment prefix, loads
// the Jinja2 prompt template from
// resources/conf/prompt/messages_to_query_template_jinja2.json under the
// working directory, and hands both to builtinM2Q.NewMessagesToQuery. The
// first error from any of those steps is returned unchanged.
func initRewriter(ctx context.Context) (messages2query.MessagesToQuery, error) {
	// The second result (whether a model is configured) is discarded here.
	chatModel, _, err := getBuiltinChatModel(ctx, "M2Q_")
	if err != nil {
		return nil, err
	}

	tpl, err := readJinja2PromptTemplate(filepath.Join(
		getWorkingDirectory(),
		"resources/conf/prompt/messages_to_query_template_jinja2.json",
	))
	if err != nil {
		return nil, err
	}

	m2q, err := builtinM2Q.NewMessagesToQuery(ctx, chatModel, tpl)
	if err != nil {
		return nil, err
	}
	return m2q, nil
}

// getWorkingDirectory returns the process working directory as reported by
// os.Getwd. When os.Getwd fails, a warning is logged and the value of the PWD
// environment variable is returned instead (which may be empty).
func getWorkingDirectory() string {
	wd, err := os.Getwd()
	if err == nil {
		return wd
	}

	logs.Warnf("[InitConfig] Failed to get current working directory: %v", err)
	return os.Getenv("PWD")
}

// readJinja2PromptTemplate loads a chat prompt template from a JSON file.
//
// The file at jsonFilePath must contain a JSON array of schema.Message
// objects. Each message becomes one entry, in file order, of the returned
// prompt.ChatTemplate, which is created with the schema.Jinja2 format type.
//
// An error is returned when the file cannot be read, when its content does not
// decode into []*schema.Message, or when the array contains a null entry. The
// underlying error is wrapped with %w, so errors.Is / errors.As still work.
func readJinja2PromptTemplate(jsonFilePath string) (prompt.ChatTemplate, error) {
	raw, err := os.ReadFile(jsonFilePath)
	if err != nil {
		// The *os.PathError from ReadFile already names the file.
		return nil, fmt.Errorf("read prompt template: %w", err)
	}

	var messages []*schema.Message
	if err = json.Unmarshal(raw, &messages); err != nil {
		// JSON errors carry no file name, so add it for diagnosability.
		return nil, fmt.Errorf("parse prompt template %q: %w", jsonFilePath, err)
	}

	templates := make([]schema.MessagesTemplate, 0, len(messages))
	for i, msg := range messages {
		// A JSON null decodes to a nil *schema.Message. Reject it here rather
		// than storing a typed-nil entry in the template list.
		if msg == nil {
			return nil, fmt.Errorf("parse prompt template %q: message %d is null", jsonFilePath, i)
		}
		templates = append(templates, msg)
	}

	return prompt.FromMessages(schema.Jinja2, templates...), nil
}

// initNL2SQL builds the natural-language-to-SQL component.
//
// It resolves the builtin chat model for the "NL2SQL_" environment prefix,
// loads the Jinja2 prompt template from
// resources/conf/prompt/nl2sql_template_jinja2.json under the working
// directory, and hands both to builtinNL2SQL.NewNL2SQL. The first error from
// any of those steps is returned unchanged.
func initNL2SQL(ctx context.Context) (nl2sql.NL2SQL, error) {
	// The second result (whether a model is configured) is discarded here.
	chatModel, _, err := getBuiltinChatModel(ctx, "NL2SQL_")
	if err != nil {
		return nil, err
	}

	tpl, err := readJinja2PromptTemplate(filepath.Join(
		getWorkingDirectory(),
		"resources/conf/prompt/nl2sql_template_jinja2.json",
	))
	if err != nil {
		return nil, err
	}

	converter, err := builtinNL2SQL.NewNL2SQL(ctx, chatModel, tpl)
	if err != nil {
		return nil, err
	}
	return converter, nil
}

func initImageX(ctx context.Context) (imagex.ImageX, error) {
uploadComponentType := os.Getenv(consts.FileUploadComponentType)
if uploadComponentType != consts.FileUploadComponentTypeImagex {
Expand Down Expand Up @@ -204,6 +310,17 @@ func initAppEventProducer() (eventbus.Producer, error) {
return appEventProducer, nil
}

// initKnowledgeEventBusProducer creates the event-bus producer for the
// knowledge topic. The name-server address is read from the environment
// variable named by consts.MQServer; the topic and group come from
// consts.RMQTopicKnowledge and consts.RMQConsumeGroupKnowledge.
func initKnowledgeEventBusProducer() (eventbus.Producer, error) {
	producer, err := eventbus.NewProducer(
		os.Getenv(consts.MQServer),
		consts.RMQTopicKnowledge,
		consts.RMQConsumeGroupKnowledge,
		2, // NOTE(review): presumably a retry count — confirm against eventbus.NewProducer
	)
	if err != nil {
		return nil, fmt.Errorf("init knowledge producer failed, err=%w", err)
	}
	return producer, nil
}

func initCodeRunner() coderunner.Runner {
switch typ := os.Getenv(consts.CodeRunnerType); typ {
case "sandbox":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package internal
package appinfra

import (
"context"
Expand All @@ -33,7 +33,7 @@ import (
"github.com/coze-dev/coze-studio/backend/infra/contract/chatmodel"
)

func GetBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.BaseChatModel, configured bool, err error) {
func getBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.BaseChatModel, configured bool, err error) {
getEnv := func(key string) string {
if val := os.Getenv(envPrefix + key); val != "" {
return val
Expand Down Expand Up @@ -99,7 +99,7 @@ func GetBuiltinChatModel(ctx context.Context, envPrefix string) (bcm chatmodel.B
}

if err != nil {
return nil, false, fmt.Errorf("knowledge init openai chat mode failed, %w", err)
return nil, false, fmt.Errorf("builtin %s chat model init failed, %w", envPrefix, err)
}
if bcm != nil {
configured = true
Expand Down
Loading
Loading