266 lines
6.4 KiB
JavaScript
266 lines
6.4 KiB
JavaScript
|
|
const { ModelVersion, TrainingData } = require('../models');
|
||
|
|
const logger = require('../utils/logger');
|
||
|
|
|
||
|
|
/**
 * Express handler that reports the active version of each known model type.
 *
 * Queries every `ModelVersion` row flagged `is_active`, newest first, and
 * responds with `{ success: true, data: { modelStatus } }`. `modelStatus` has
 * one key per known type (`MODEL1`, `QUERYMODEL`); each value is `null` when
 * no active row of that type was returned, otherwise a summary with `id`,
 * `version`, `deployment_status`, `performance_metrics` and `last_updated`.
 *
 * Any thrown error is logged and answered with a 500 JSON payload.
 *
 * @param {object} req - Express request (not read).
 * @param {object} res - Express response used to send the JSON payload.
 * @returns {Promise<void>}
 */
const getModelStatus = async (req, res) => {
  try {
    const activeModels = await ModelVersion.findAll({
      where: { is_active: true },
      order: [['created_at', 'DESC']]
    });

    const modelStatus = {
      MODEL1: null,
      QUERYMODEL: null
    };

    // Rows arrive newest-first, so the first row seen for a type is the most
    // recent one. Later rows of the same type must not overwrite it, otherwise
    // the oldest active row would be reported when several are active.
    for (const model of activeModels) {
      const type = model.model_type;
      const isKnownType = Object.prototype.hasOwnProperty.call(modelStatus, type);
      if (!isKnownType || modelStatus[type] !== null) {
        continue;
      }

      modelStatus[type] = {
        id: model.id,
        version: model.version,
        deployment_status: model.deployment_status,
        performance_metrics: model.performance_metrics,
        last_updated: model.updated_at
      };
    }

    res.json({
      success: true,
      data: { modelStatus }
    });
  } catch (error) {
    logger.error('Get model status error:', error);
    res.status(500).json({
      success: false,
      error: 'Internal server error'
    });
  }
};
|
||
|
|
|
||
|
|
/**
 * Express handler that lists model versions, newest first, with pagination.
 *
 * Query parameters (all optional):
 *   - `modelType`: when truthy, restricts results to that `model_type`.
 *   - `page`: 1-based page number, default 1.
 *   - `limit`: page size, default 10.
 *
 * `page` and `limit` are parsed as base-10 integers. Values that are not
 * numeric or are below 1 fall back to the defaults, so the query never gets a
 * negative/NaN offset and the page count never divides by zero.
 *
 * Responds with `{ success: true, data: { models, pagination } }`; any thrown
 * error is logged and answered with a 500 JSON payload.
 *
 * @param {object} req - Express request; `req.query` is read.
 * @param {object} res - Express response used to send the JSON payload.
 * @returns {Promise<void>}
 */
const getModelVersions = async (req, res) => {
  try {
    const { modelType, page = 1, limit = 10 } = req.query;

    // `NaN >= 1` is false, so unparsable input takes the fallback branch too.
    // NOTE(review): `limit` has no upper bound; consider capping it if this
    // endpoint is reachable by untrusted clients.
    const requestedPage = Number.parseInt(page, 10);
    const requestedLimit = Number.parseInt(limit, 10);
    const currentPage = requestedPage >= 1 ? requestedPage : 1;
    const pageSize = requestedLimit >= 1 ? requestedLimit : 10;

    const whereClause = {};
    if (modelType) whereClause.model_type = modelType;

    const models = await ModelVersion.findAndCountAll({
      where: whereClause,
      order: [['created_at', 'DESC']],
      limit: pageSize,
      offset: (currentPage - 1) * pageSize
    });

    res.json({
      success: true,
      data: {
        models: models.rows,
        pagination: {
          page: currentPage,
          limit: pageSize,
          total: models.count,
          pages: Math.ceil(models.count / pageSize)
        }
      }
    });
  } catch (error) {
    logger.error('Get model versions error:', error);
    res.status(500).json({
      success: false,
      error: 'Internal server error'
    });
  }
};
|
||
|
|
|
||
|
|
/**
 * Express handler that registers a new model version in `training` state.
 *
 * Body fields: `modelName`, `modelType` and `baseModel` are required;
 * `fineTuningMethod` is optional; `hyperparameters` defaults to `{}`.
 * `modelType` must be `MODEL1` or `QUERYMODEL`.
 *
 * The currently active version(s) of the same type are flagged inactive and
 * the new row is inserted with a timestamp-derived `version` string. Both
 * writes run inside a single managed transaction so that a failed insert
 * rolls the deactivation back.
 *
 * Responses: 201 with `{ success: true, data: { model } }`; 400 on missing or
 * invalid input; 500 (after logging) on any thrown error.
 *
 * NOTE(review): assumes `ModelVersion` is a Sequelize model exposing the
 * static `sequelize` instance — confirm against `../models`.
 *
 * @param {object} req - Express request; `req.body` is read.
 * @param {object} res - Express response used to send the JSON payload.
 * @returns {Promise<*>}
 */
const createModelVersion = async (req, res) => {
  try {
    const {
      modelName,
      modelType,
      baseModel,
      fineTuningMethod,
      hyperparameters = {}
    } = req.body;

    if (!modelName || !modelType || !baseModel) {
      return res.status(400).json({
        success: false,
        error: 'Model name, type, and base model are required'
      });
    }

    const validModelTypes = ['MODEL1', 'QUERYMODEL'];
    if (!validModelTypes.includes(modelType)) {
      return res.status(400).json({
        success: false,
        error: 'Invalid model type'
      });
    }

    // Deactivation and insert succeed or fail together: without the
    // transaction a failed insert would leave the type with no active version.
    const model = await ModelVersion.sequelize.transaction(async (transaction) => {
      // Deactivate current active model of same type
      await ModelVersion.update(
        { is_active: false },
        { where: { model_type: modelType, is_active: true }, transaction }
      );

      return ModelVersion.create(
        {
          model_name: modelName,
          version: `v${Date.now()}`,
          model_type: modelType,
          base_model: baseModel,
          fine_tuning_method: fineTuningMethod,
          hyperparameters,
          deployment_status: 'training'
        },
        { transaction }
      );
    });

    logger.info(`New model version created: ${model.id}`);

    res.status(201).json({
      success: true,
      data: { model }
    });
  } catch (error) {
    logger.error('Create model version error:', error);
    res.status(500).json({
      success: false,
      error: 'Internal server error'
    });
  }
};
|
||
|
|
|
||
|
|
/**
 * Express handler that applies a partial update to one model version.
 *
 * Looks the row up by `req.params.modelId` and copies the supplied body
 * fields onto it:
 *   - `deploymentStatus`   -> `deployment_status` (only when truthy)
 *   - `performanceMetrics` -> `performance_metrics`
 *   - `modelPath`          -> `model_path`
 *   - `trainingLog`        -> `training_log`
 *
 * The three data fields are applied whenever they are neither `undefined`
 * nor `null`, so legitimate falsy values (an empty log, a metric of `0`) are
 * stored instead of being silently dropped.
 *
 * Responses: 200 with `{ success: true, data: { model } }`; 404 when the id
 * does not match a row; 500 (after logging) on any thrown error.
 *
 * @param {object} req - Express request; `req.params` and `req.body` are read.
 * @param {object} res - Express response used to send the JSON payload.
 * @returns {Promise<*>}
 */
const updateModelVersion = async (req, res) => {
  try {
    const { modelId } = req.params;
    const {
      deploymentStatus,
      performanceMetrics,
      modelPath,
      trainingLog
    } = req.body;

    const model = await ModelVersion.findByPk(modelId);

    if (!model) {
      return res.status(404).json({
        success: false,
        error: 'Model version not found'
      });
    }

    // "Provided" means present in the payload with a real value; `null` and
    // `undefined` both mean "leave this column alone".
    const isProvided = (value) => value !== undefined && value !== null;

    const updateData = {};
    // The status keeps its truthiness guard: an empty status string is ignored.
    if (deploymentStatus) updateData.deployment_status = deploymentStatus;
    if (isProvided(performanceMetrics)) updateData.performance_metrics = performanceMetrics;
    if (isProvided(modelPath)) updateData.model_path = modelPath;
    if (isProvided(trainingLog)) updateData.training_log = trainingLog;

    await model.update(updateData);

    logger.info(`Model version updated: ${modelId}`);

    res.json({
      success: true,
      data: { model }
    });
  } catch (error) {
    logger.error('Update model version error:', error);
    res.status(500).json({
      success: false,
      error: 'Internal server error'
    });
  }
};
|
||
|
|
|
||
|
|
/**
 * Express handler that makes one model version the active one for its type.
 *
 * Looks the row up by `req.params.modelId`, flags every active row of the
 * same `model_type` as inactive, then marks the requested row active with
 * `deployment_status` set to `deployed`. Both writes run inside a single
 * managed transaction so a failure in the second write cannot leave the type
 * without an active version.
 *
 * Responses: 200 with `{ success: true, data: { model } }`; 404 when the id
 * does not match a row; 500 (after logging) on any thrown error.
 *
 * NOTE(review): assumes `ModelVersion` is a Sequelize model exposing the
 * static `sequelize` instance — confirm against `../models`.
 *
 * @param {object} req - Express request; `req.params.modelId` is read.
 * @param {object} res - Express response used to send the JSON payload.
 * @returns {Promise<*>}
 */
const activateModel = async (req, res) => {
  try {
    const { modelId } = req.params;

    const model = await ModelVersion.findByPk(modelId);

    if (!model) {
      return res.status(404).json({
        success: false,
        error: 'Model version not found'
      });
    }

    await ModelVersion.sequelize.transaction(async (transaction) => {
      // Deactivate other models of same type
      await ModelVersion.update(
        { is_active: false },
        { where: { model_type: model.model_type, is_active: true }, transaction }
      );

      // Activate this model
      await model.update(
        {
          is_active: true,
          deployment_status: 'deployed'
        },
        { transaction }
      );
    });

    logger.info(`Model activated: ${modelId}`);

    res.json({
      success: true,
      data: { model }
    });
  } catch (error) {
    logger.error('Activate model error:', error);
    res.status(500).json({
      success: false,
      error: 'Internal server error'
    });
  }
};
|
||
|
|
|
||
|
|
/**
 * Express handler that lists training-data rows, newest first, with pagination.
 *
 * Query parameters (all optional):
 *   - `modelVersionId`: when truthy, restricts results to that `model_version_id`.
 *   - `dataType`: when truthy, restricts results to that `data_type`.
 *   - `page`: 1-based page number, default 1.
 *   - `limit`: page size, default 10.
 *
 * `page` and `limit` are parsed as base-10 integers. Values that are not
 * numeric or are below 1 fall back to the defaults, so the query never gets a
 * negative/NaN offset and the page count never divides by zero.
 *
 * Responds with `{ success: true, data: { trainingData, pagination } }`; any
 * thrown error is logged and answered with a 500 JSON payload.
 *
 * @param {object} req - Express request; `req.query` is read.
 * @param {object} res - Express response used to send the JSON payload.
 * @returns {Promise<void>}
 */
const getTrainingData = async (req, res) => {
  try {
    const { modelVersionId, dataType, page = 1, limit = 10 } = req.query;

    // `NaN >= 1` is false, so unparsable input takes the fallback branch too.
    // NOTE(review): `limit` has no upper bound; consider capping it if this
    // endpoint is reachable by untrusted clients.
    const requestedPage = Number.parseInt(page, 10);
    const requestedLimit = Number.parseInt(limit, 10);
    const currentPage = requestedPage >= 1 ? requestedPage : 1;
    const pageSize = requestedLimit >= 1 ? requestedLimit : 10;

    const whereClause = {};
    if (modelVersionId) whereClause.model_version_id = modelVersionId;
    if (dataType) whereClause.data_type = dataType;

    const trainingData = await TrainingData.findAndCountAll({
      where: whereClause,
      order: [['created_at', 'DESC']],
      limit: pageSize,
      offset: (currentPage - 1) * pageSize
    });

    res.json({
      success: true,
      data: {
        trainingData: trainingData.rows,
        pagination: {
          page: currentPage,
          limit: pageSize,
          total: trainingData.count,
          pages: Math.ceil(trainingData.count / pageSize)
        }
      }
    });
  } catch (error) {
    logger.error('Get training data error:', error);
    res.status(500).json({
      success: false,
      error: 'Internal server error'
    });
  }
};
|
||
|
|
|
||
|
|
module.exports = {
|
||
|
|
getModelStatus,
|
||
|
|
getModelVersions,
|
||
|
|
createModelVersion,
|
||
|
|
updateModelVersion,
|
||
|
|
activateModel,
|
||
|
|
getTrainingData
|
||
|
|
};
|