فهرست منبع

bug fix 学习进度不显示learning

roo00 6 سال پیش
والد
کامیت
b5a1bba835

+ 2 - 1
o2server/x_query_assemble_designer/src/main/java/com/x/query/assemble/designer/jaxrs/neural/ActionGetModel.java

@@ -8,6 +8,7 @@ import com.x.base.core.project.bean.WrapCopierFactory;
 import com.x.base.core.project.exception.ExceptionEntityNotExist;
 import com.x.base.core.project.http.ActionResult;
 import com.x.base.core.project.http.EffectivePerson;
+import com.x.base.core.project.tools.ListTools;
 import com.x.query.assemble.designer.Business;
 import com.x.query.core.entity.neural.Model;
 
@@ -31,7 +32,7 @@ class ActionGetModel extends BaseAction {
 		private static final long serialVersionUID = -6541538280679110474L;
 
 		static WrapCopier<Model, Wo> copier = WrapCopierFactory.wo(Model.class, Wo.class, null,
-				JpaObject.FieldsInvisible);
+				ListTools.toList(JpaObject.FieldsInvisible, Model.nnet_FIELDNAME, Model.intermediateNnet_FIELDNAME));
 
 	}
 

+ 4 - 2
o2server/x_query_core_entity/src/main/java/com/x/query/core/entity/neural/Model.java

@@ -120,6 +120,8 @@ public class Model extends SliceJpaObject {
 
 	public static final String STATUS_GENERATING = "generating";
 	public static final String STATUS_LEARNING = "learning";
+	public static final String STATUS_COMPLETED = "completed";
+	public static final String STATUS_EXCESSIVE = "excessive";
 
 	public static final String DATATYPE_PROCESSPLATFORM = "processPlatform";
 	public static final String DATATYPE_CMS = "cms";
@@ -260,7 +262,7 @@ public class Model extends SliceJpaObject {
 	@PersistentCollection(fetch = FetchType.EAGER)
 	@ContainerTable(name = TABLE + ContainerTableNameMiddle + processList_FIELDNAME, joinIndex = @Index(name = TABLE
 			+ IndexNameMiddle + processList_FIELDNAME + JoinIndexNameSuffix))
-	@OrderColumn(name =  ORDERCOLUMNCOLUMN)
+	@OrderColumn(name = ORDERCOLUMNCOLUMN)
 	@ElementColumn(length = length_255B, name = ColumnNamePrefix + processList_FIELDNAME)
 	@ElementIndex(name = TABLE + IndexNameMiddle + processList_FIELDNAME + ElementIndexNameSuffix)
 	@CheckPersist(allowEmpty = true)
@@ -271,7 +273,7 @@ public class Model extends SliceJpaObject {
 	@PersistentCollection(fetch = FetchType.EAGER)
 	@ContainerTable(name = TABLE + ContainerTableNameMiddle + applicationList_FIELDNAME, joinIndex = @Index(name = TABLE
 			+ IndexNameMiddle + applicationList_FIELDNAME + JoinIndexNameSuffix))
-	@OrderColumn(name =  ORDERCOLUMNCOLUMN)
+	@OrderColumn(name = ORDERCOLUMNCOLUMN)
 	@ElementColumn(length = length_255B, name = ColumnNamePrefix + applicationList_FIELDNAME)
 	@ElementIndex(name = TABLE + IndexNameMiddle + applicationList_FIELDNAME + ElementIndexNameSuffix)
 	@CheckPersist(allowEmpty = true)

+ 9 - 11
o2server/x_query_service_processing/src/main/java/com/x/query/service/processing/jaxrs/neural/Generate.java

@@ -111,8 +111,8 @@ public class Generate {
 			if (StringUtils.equals(Model.STATUS_GENERATING, model.getStatus())) {
 				throw new ExceptionGenerate(model.getName());
 			}
-			final Double validationRate = (MapTools.getDouble(model.getPropertyMap(),
-					Model.PROPERTY_MLP_VALIDATIONRATE, Model.DEFAULT_MLP_VALIDATIONRATE));
+			final Double validationRate = (MapTools.getDouble(model.getPropertyMap(), Model.PROPERTY_MLP_VALIDATIONRATE,
+					Model.DEFAULT_MLP_VALIDATIONRATE));
 			List<String> bundles = this.listBundle(business, model);
 			if (ListTools.isEmpty(bundles)) {
 				throw new ExceptionBundleEmpty(model.getName());
@@ -164,8 +164,7 @@ public class Generate {
 					outValues.clear();
 					this.convert(business, converter, scriptHelper, lph, model, workCompleted, inValues, outValues);
 					if ((!inValues.isEmpty()) && (!outValues.isEmpty())) {
-						this.createValidationEntry(business, model, workCompleted, inBag, outBag, inValues,
-								outValues);
+						this.createValidationEntry(business, model, workCompleted, inBag, outBag, inValues, outValues);
 						testEntryCount++;
 					}
 				}
@@ -181,7 +180,8 @@ public class Generate {
 			model = this.refreshModel(business, modelId);
 			emc.beginTransaction(Model.class);
 			model.setStatus("");
-			model.setEntryCount(bundles.size());
+			model.setEntryCount(total);
+			model.setEffectiveEntryCount(bundles.size());
 			model.setGeneratingPercent(100);
 			model.setLearnEntryCount(learnEntryCount);
 			model.setValidationEntryCount(testEntryCount);
@@ -329,8 +329,8 @@ public class Generate {
 		Root<WorkCompleted> root = cq.from(WorkCompleted.class);
 		Predicate p = cb.conjunction();
 		if (ListTools.isNotEmpty(model.getApplicationList())) {
-			p = cb.and(p, cb.isMember(root.get(WorkCompleted.application_FIELDNAME),
-					cb.literal(model.getApplicationList())));
+			p = cb.and(p,
+					cb.isMember(root.get(WorkCompleted.application_FIELDNAME), cb.literal(model.getApplicationList())));
 		}
 		if (ListTools.isNotEmpty(model.getProcessList())) {
 			p = cb.and(p, cb.isMember(root.get(WorkCompleted.process_FIELDNAME), cb.literal(model.getProcessList())));
@@ -373,8 +373,7 @@ public class Generate {
 	}
 
 	private Long cleanInText(Business business, String modelId) throws Exception {
-		List<String> ids = business.entityManagerContainer().idsEqual(InText.class, InText.model_FIELDNAME,
-				modelId);
+		List<String> ids = business.entityManagerContainer().idsEqual(InText.class, InText.model_FIELDNAME, modelId);
 		Long count = 0L;
 		for (List<String> os : ListTools.batch(ids, 2000)) {
 			business.entityManagerContainer().beginTransaction(InText.class);
@@ -385,8 +384,7 @@ public class Generate {
 	}
 
 	private Long cleanOutText(Business business, String modelId) throws Exception {
-		List<String> ids = business.entityManagerContainer().idsEqual(OutText.class, OutText.model_FIELDNAME,
-				modelId);
+		List<String> ids = business.entityManagerContainer().idsEqual(OutText.class, OutText.model_FIELDNAME, modelId);
 		Long count = 0L;
 		for (List<String> os : ListTools.batch(ids, 2000)) {
 			business.entityManagerContainer().beginTransaction(OutText.class);

+ 15 - 1
o2server/x_query_service_processing/src/main/java/com/x/query/service/processing/jaxrs/neural/Learn.java

@@ -162,7 +162,6 @@ public class Learn {
 				model = refreshmodel(business, modelId);
 				emc.beginTransaction(Model.class);
 				model.setNnet(this.encode(neuralNetwork));
-				model.setStatus("");
 				model.setInValueCount(inTextBag.size());
 				model.setOutValueCount(outTextBag.size());
 				emc.commit();
@@ -170,10 +169,20 @@ public class Learn {
 				inTextBag.saveToInValue(business);
 				this.cleanOutValue(business, model);
 				outTextBag.saveToOutValue(business);
+				if (logger.isDebug()) {
+					File file = new File(Config.dir_local_temp(), model.getId() + ".nnet");
+					neuralNetwork.save(file.getAbsolutePath());
+					logger.debug("save nnet file to ", file.getAbsolutePath());
+				}
 				if (neuralNetwork.getLearningRule().getErrorFunction().getTotalError() > maxError) {
 					logger.print("神经网络多层感知机 ({}) 学习失败, 耗时: {}, 总误差: {}, 未能达到预期值: {}.", modelName,
 							stamp.consumingMilliseconds(),
 							neuralNetwork.getLearningRule().getErrorFunction().getTotalError(), maxError);
+					emc.beginTransaction(Model.class);
+					model.setValidationMeanSquareError(
+							neuralNetwork.getLearningRule().getErrorFunction().getTotalError());
+					model.setStatus(Model.STATUS_EXCESSIVE);
+					emc.commit();
 				} else {
 					logger.print("神经网络多层感知机 ({}) 学习完成.", modelName);
 					if (!validationSet.isEmpty()) {
@@ -183,9 +192,14 @@ public class Learn {
 						model = refreshmodel(business, modelId);
 						emc.beginTransaction(Model.class);
 						model.setValidationMeanSquareError(evaluationResult.getMeanSquareError());
+						model.setStatus(Model.STATUS_COMPLETED);
 						emc.commit();
 						logger.print("神经网络多层感知机 ({}) 测试数据数量: {}, 测试结果集标准方差: {}.", modelName, validationSet.size(),
 								evaluationResult.getMeanSquareError());
+					} else {
+						emc.beginTransaction(Model.class);
+						model.setStatus(Model.STATUS_COMPLETED);
+						emc.commit();
 					}
 					// logger.info("##############################################################################");
 					// logger.info("MeanSquare Error: " +