Vowpal Wabbit Java: get raw predictions

I am using the Java API of Vowpal Wabbit to get predictions. I need the raw predictions (the same output that -r output.txt writes), but I couldn't find any such method in the VWMulticlassLearner class. I train my model from Python via the command line with the following arguments:

vw -f model_filepath -c --cache_file cache_filepath -k --csoaa 40 -b 24 -q cd -q .... -q n: --ignore a --ignore x

and we use the following Java code to create the learners for prediction:

VWLearners.create("-i ./data/train.model  -t --quiet"); // VWMulticlassLearner
VWLearners.create("-i ./data/train.model  -t --quiet --csoaa_ldf=mc --loss_function=logistic --probabilities"); //VWProbLearner

Neither class has a method that returns the raw predictions.

I want the same raw predictions that the command line writes, as shown below:

$ echo ' .. sample string .. ' | vw -i data/train.model -t -r test -p /dev/stdout
creating quadratic features for pairs: cd ce cu cw de du dw eu ew uw n:
ignoring namespaces beginning with: a x
only testing
predictions = /dev/stdout
raw predictions = test
Num weight bits = 24
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile =
num sources = 1
average  since         example        example  current  current  current
loss     last          counter         weight    label  predict features
39
0.000000 0.000000            1            1.0    known       39      171

finished run
number of examples per pass = 1
passes used = 1
weighted example sum = 1.000000
weighted label sum = 0.000000
average loss = 0.000000
total feature number = 171

$ cat test
0:1.05645 1:0.83437 2:-0.210798 3:-2.81048 4:-4.47558 5:-4.45883 6:-3.65177 7:-3.71191 8:-2.96008 9:-2.82846 10:-2.31816 11:0.925984 12:3.28547 13:5.20375 14:6.34244 15:6.13525 16:1.65726 17:1.22801 18:1.35034 19:3.27091 20:2.94066 21:-0.0276409 22:0.391437 23:1.267 24:-0.689573 25:0.0171876 26:3.12935 27:3.95045 28:3.86978 29:1.18468 30:0.0921049 31:0.436564 32:0.98946 33:1.00963 34:-0.265355 35:-3.02128 36:-2.52846 37:-2.8066 38:-3.50639 39:-4.6184

How can I get the values that are written to the file test as the return value of a Java method? I don't want to read the file back in Java, because that would be slow.



Solution 1:[1]

I ended up using one of the abandoned PRs. Here is my working git patch file:

diff --git a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc
index 6b51c4d30..f3ccb6621 100644
--- a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc
+++ b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc
@@ -11,3 +11,17 @@ JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predict(JNI
 JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predictMultiline(JNIEnv *env, jobject obj, jobjectArray example_strings, jboolean learn, jlong vwPtr)
 { return base_predict<jint>(env, example_strings, learn, vwPtr, multiclass_predictor);
 }
+
+jfloatArray multiclass_raw_predictor(example* vec, JNIEnv *env){
+  size_t num_values = vec->l.cs.costs.size();
+  jfloatArray j_labels = env->NewFloatArray(num_values);
+  for (size_t i = 0; i < num_values; i++) {
+    jfloat f[] = { vec->l.cs.costs[i].partial_prediction };
+    env->SetFloatArrayRegion(j_labels, i, 1, (float*)f);
+   }
+   return j_labels;
+ }
+
+JNIEXPORT jfloatArray JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_rawPredict(JNIEnv *env, jobject obj, jstring example_string, jboolean learn, jlong vwPtr){
+  return base_predict<jfloatArray>(env, example_string, learn, vwPtr, multiclass_raw_predictor);
+}
diff --git a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h
index 05204d53e..5610704fa 100644
--- a/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h
+++ b/java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.h
@@ -24,6 +24,15 @@ JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predict
 JNIEXPORT jint JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_predictMultiline
 (JNIEnv *, jobject, jobjectArray, jboolean, jlong);
 
+/*
+ * Class:     vowpalWabbit_learner_VWMulticlassLearner
+ * Method:    rawPredict
+ * Signature: (Ljava/lang/String;ZJ)[F
+ */
+JNIEXPORT jfloatArray JNICALL Java_vowpalWabbit_learner_VWMulticlassLearner_rawPredict
+  (JNIEnv *, jobject, jstring, jboolean, jlong);
+
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java b/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java
index b506cfb25..bb3156351 100644
--- a/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java
+++ b/java/src/main/java/vowpalWabbit/learner/VWMulticlassLearner.java
@@ -13,4 +13,25 @@ final public class VWMulticlassLearner extends VWIntLearner {
 
     @Override
     protected native int predictMultiline(String[] example, boolean learn, long nativePointer);
+
+    protected native float[] rawPredict(String example, boolean learn, long nativePointer);
+
+    /**
+     * Get raw prediction output.
+     *
+     * @param example a single vw example string
+     * @return Raw prediction
+     */
+
+    public float[] rawPredict(final String example) {
+        lock.lock();
+        try {
+            if (isOpen()) {
+                return rawPredict(example, false, nativePointer);
+            }
+            throw new IllegalStateException("Already closed.");
+        } finally {
+            lock.unlock();
+        }
+    }
 }
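
With the patch applied and the native library rebuilt, rawPredict(example) should return one float per cost entry, taken from partial_prediction. A quick way to gain confidence in the result is to compare it with a line of vw -r output for the same example. The sketch below is only an illustration of that check: the class RawScores and its methods parseRawLine and argMin are hypothetical helpers, not part of the Vowpal Wabbit Java API. It assumes the index:score format shown in the question, and it assumes that array positions line up with the indices in the file, which is worth verifying for your VW version.

```java
/** Illustrative helpers for raw per-class scores; not part of the VW Java API. */
final class RawScores {

    /** Parses one line of `vw -r` output ("0:1.05 1:0.83 ...") into scores in file order. */
    static float[] parseRawLine(String line) {
        String[] tokens = line.trim().split("\\s+");
        float[] scores = new float[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            // Keep only the part after "index:"; the index itself is ignored here.
            int colon = tokens[i].indexOf(':');
            scores[i] = Float.parseFloat(tokens[i].substring(colon + 1));
        }
        return scores;
    }

    /** Returns the position of the lowest score; on ties the first position wins. */
    static int argMin(float[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("no scores");
        }
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] < scores[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // The raw prediction line from the question, copied verbatim.
        String line = "0:1.05645 1:0.83437 2:-0.210798 3:-2.81048 4:-4.47558 5:-4.45883 "
                + "6:-3.65177 7:-3.71191 8:-2.96008 9:-2.82846 10:-2.31816 11:0.925984 "
                + "12:3.28547 13:5.20375 14:6.34244 15:6.13525 16:1.65726 17:1.22801 "
                + "18:1.35034 19:3.27091 20:2.94066 21:-0.0276409 22:0.391437 23:1.267 "
                + "24:-0.689573 25:0.0171876 26:3.12935 27:3.95045 28:3.86978 29:1.18468 "
                + "30:0.0921049 31:0.436564 32:0.98946 33:1.00963 34:-0.265355 35:-3.02128 "
                + "36:-2.52846 37:-2.8066 38:-3.50639 39:-4.6184";
        float[] scores = parseRawLine(line);
        System.out.println(scores.length + " scores, lowest at position " + argMin(scores));
    }
}
```

For the line from the question this prints "40 scores, lowest at position 39", which agrees with the predicted label 39 in the command-line output, as expected for csoaa, which picks the class with the lowest predicted cost. The float[] returned by the patched rawPredict can be passed to argMin in the same way, so the file only needs to be read once, during testing.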

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Nishant Kumar